driftctl/pkg/remote/azurerm/repository/network_test.go

1173 lines
32 KiB
Go

package repository
import (
"context"
"fmt"
"reflect"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/pkg/errors"
"github.com/snyk/driftctl/pkg/remote/cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func Test_ListAllVirtualNetwork_MultiplesResults(t *testing.T) {
expected := []*armnetwork.VirtualNetwork{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network2"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network3"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network4"),
},
},
}
fakeClient := &mockVirtualNetworkClient{}
mockPager := &mockVirtualNetworksListAllPager{}
mockPager.On("Err").Return(nil).Times(3)
mockPager.On("NextPage", mock.Anything).Return(true).Times(2)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{
VirtualNetworksListAllResult: armnetwork.VirtualNetworksListAllResult{
VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{
Value: []*armnetwork.VirtualNetwork{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network2"),
},
},
},
},
},
}).Times(1)
mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{
VirtualNetworksListAllResult: armnetwork.VirtualNetworksListAllResult{
VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{
Value: []*armnetwork.VirtualNetwork{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network3"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network4"),
},
},
},
},
},
}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
c := &cache.MockCache{}
c.On("GetAndLock", "ListAllVirtualNetworks").Return(nil).Times(1)
c.On("Unlock", "ListAllVirtualNetworks").Times(1)
c.On("Put", "ListAllVirtualNetworks", expected).Return(true).Times(1)
s := &networkRepository{
virtualNetworksClient: fakeClient,
cache: c,
}
got, err := s.ListAllVirtualNetworks()
if err != nil {
t.Errorf("ListAllVirtualNetworks() error = %v", err)
return
}
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllVirtualNetworks() got = %v, want %v", got, expected)
}
}
func Test_ListAllVirtualNetwork_MultiplesResults_WithCache(t *testing.T) {
expected := []*armnetwork.VirtualNetwork{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("network3"),
},
},
}
fakeClient := &mockVirtualNetworkClient{}
c := &cache.MockCache{}
c.On("GetAndLock", "ListAllVirtualNetworks").Return(expected).Times(1)
c.On("Unlock", "ListAllVirtualNetworks").Times(1)
s := &networkRepository{
virtualNetworksClient: fakeClient,
cache: c,
}
got, err := s.ListAllVirtualNetworks()
if err != nil {
t.Errorf("ListAllVirtualNetworks() error = %v", err)
return
}
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllVirtualNetworks() got = %v, want %v", got, expected)
}
}
func Test_ListAllVirtualNetwork_Error_OnPageResponse(t *testing.T) {
fakeClient := &mockVirtualNetworkClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockVirtualNetworksListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(true).Times(1)
mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
virtualNetworksClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllVirtualNetworks()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllVirtualNetwork_Error(t *testing.T) {
fakeClient := &mockVirtualNetworkClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockVirtualNetworksListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
virtualNetworksClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllVirtualNetworks()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllRouteTables_MultiplesResults(t *testing.T) {
expected := []*armnetwork.RouteTable{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("table1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("table2"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("table3"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("table4"),
},
},
}
fakeClient := &mockRouteTablesClient{}
mockPager := &mockRouteTablesListAllPager{}
mockPager.On("Err").Return(nil).Times(3)
mockPager.On("NextPage", mock.Anything).Return(true).Times(2)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{
RouteTablesListAllResult: armnetwork.RouteTablesListAllResult{
RouteTableListResult: armnetwork.RouteTableListResult{
Value: expected[:2],
},
},
}).Times(1)
mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{
RouteTablesListAllResult: armnetwork.RouteTablesListAllResult{
RouteTableListResult: armnetwork.RouteTableListResult{
Value: expected[2:],
},
},
}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
c := &cache.MockCache{}
c.On("GetAndLock", "ListAllRouteTables").Return(nil).Times(1)
c.On("Unlock", "ListAllRouteTables").Times(1)
c.On("Put", "ListAllRouteTables", expected).Return(true).Times(1)
s := &networkRepository{
routeTableClient: fakeClient,
cache: c,
}
got, err := s.ListAllRouteTables()
if err != nil {
t.Errorf("ListAllRouteTables() error = %v", err)
return
}
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllRouteTables() got = %v, want %v", got, expected)
}
}
func Test_ListAllRouteTables_MultiplesResults_WithCache(t *testing.T) {
expected := []*armnetwork.RouteTable{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("table1"),
},
},
}
fakeClient := &mockRouteTablesClient{}
c := &cache.MockCache{}
c.On("GetAndLock", "ListAllRouteTables").Return(expected).Times(1)
c.On("Unlock", "ListAllRouteTables").Times(1)
s := &networkRepository{
routeTableClient: fakeClient,
cache: c,
}
got, err := s.ListAllRouteTables()
if err != nil {
t.Errorf("ListAllRouteTables() error = %v", err)
return
}
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllRouteTables() got = %v, want %v", got, expected)
}
}
func Test_ListAllRouteTables_Error_OnPageResponse(t *testing.T) {
fakeClient := &mockRouteTablesClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockRouteTablesListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(true).Times(1)
mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
routeTableClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllRouteTables()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllRouteTables_Error(t *testing.T) {
fakeClient := &mockRouteTablesClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockRouteTablesListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
routeTableClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllRouteTables()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllSubnets_MultiplesResults(t *testing.T) {
network := &armnetwork.VirtualNetwork{
Resource: armnetwork.Resource{
Name: to.StringPtr("network1"),
ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"),
},
}
expected := []*armnetwork.Subnet{
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet1"),
},
},
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet2"),
},
},
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet3"),
},
},
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet4"),
},
},
}
fakeClient := &mockSubnetsClient{}
mockPager := &mockSubnetsListPager{}
mockPager.On("Err").Return(nil).Times(3)
mockPager.On("NextPage", mock.Anything).Return(true).Times(2)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{
SubnetsListResult: armnetwork.SubnetsListResult{
SubnetListResult: armnetwork.SubnetListResult{
Value: []*armnetwork.Subnet{
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet1"),
},
},
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet2"),
},
},
},
},
},
}).Times(1)
mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{
SubnetsListResult: armnetwork.SubnetsListResult{
SubnetListResult: armnetwork.SubnetListResult{
Value: []*armnetwork.Subnet{
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet3"),
},
},
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("subnet4"),
},
},
},
},
},
}).Times(1)
fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager)
c := &cache.MockCache{}
cacheKey := fmt.Sprintf("ListAllSubnets_%s", *network.ID)
c.On("Get", cacheKey).Return(nil).Times(1)
c.On("Put", cacheKey, expected).Return(true).Times(1)
s := &networkRepository{
subnetsClient: fakeClient,
cache: c,
}
got, err := s.ListAllSubnets(network)
if err != nil {
t.Errorf("ListAllSubnets() error = %v", err)
return
}
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllSubnets() got = %v, want %v", got, expected)
}
}
func Test_ListAllSubnets_MultiplesResults_WithCache(t *testing.T) {
network := &armnetwork.VirtualNetwork{
Resource: armnetwork.Resource{
ID: to.StringPtr("networkID"),
},
}
expected := []*armnetwork.Subnet{
{
Name: to.StringPtr("network1"),
},
}
fakeClient := &mockSubnetsClient{}
c := &cache.MockCache{}
c.On("Get", "ListAllSubnets_networkID").Return(expected).Times(1)
s := &networkRepository{
subnetsClient: fakeClient,
cache: c,
}
got, err := s.ListAllSubnets(network)
if err != nil {
t.Errorf("ListAllSubnets() error = %v", err)
return
}
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllSubnets() got = %v, want %v", got, expected)
}
}
func Test_ListAllSubnets_Error_OnPageResponse(t *testing.T) {
network := &armnetwork.VirtualNetwork{
Resource: armnetwork.Resource{
Name: to.StringPtr("network1"),
ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"),
},
}
fakeClient := &mockSubnetsClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockSubnetsListPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(true).Times(1)
mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{}).Times(1)
fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager)
s := &networkRepository{
subnetsClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllSubnets(network)
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllSubnets_Error(t *testing.T) {
network := &armnetwork.VirtualNetwork{
Resource: armnetwork.Resource{
Name: to.StringPtr("network1"),
ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"),
},
}
fakeClient := &mockSubnetsClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockSubnetsListPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager)
s := &networkRepository{
subnetsClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllSubnets(network)
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllSubnets_ErrorOnInvalidNetworkID(t *testing.T) {
network := &armnetwork.VirtualNetwork{
Resource: armnetwork.Resource{
Name: to.StringPtr("network1"),
ID: to.StringPtr("foobar"),
},
}
fakeClient := &mockSubnetsClient{}
expectedErr := errors.New("parsing failed for foobar. Invalid resource Id format")
s := &networkRepository{
subnetsClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllSubnets(network)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr.Error(), err.Error())
assert.Nil(t, got)
}
func Test_ListAllFirewalls_MultiplesResults(t *testing.T) {
expected := []*armnetwork.AzureFirewall{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("firewall1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("firewall2"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("firewall3"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("firewall4"),
},
},
}
fakeClient := &mockFirewallsClient{}
mockPager := &mockFirewallsListAllPager{}
mockPager.On("Err").Return(nil).Times(3)
mockPager.On("NextPage", mock.Anything).Return(true).Times(2)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{
AzureFirewallsListAllResult: armnetwork.AzureFirewallsListAllResult{
AzureFirewallListResult: armnetwork.AzureFirewallListResult{
Value: expected[:2],
},
},
}).Times(1)
mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{
AzureFirewallsListAllResult: armnetwork.AzureFirewallsListAllResult{
AzureFirewallListResult: armnetwork.AzureFirewallListResult{
Value: expected[2:],
},
},
}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
c := &cache.MockCache{}
c.On("Get", "ListAllFirewalls").Return(nil).Times(1)
c.On("Put", "ListAllFirewalls", expected).Return(true).Times(1)
s := &networkRepository{
firewallsClient: fakeClient,
cache: c,
}
got, err := s.ListAllFirewalls()
if err != nil {
t.Errorf("ListAllFirewalls() error = %v", err)
return
}
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllFirewalls() got = %v, want %v", got, expected)
}
}
func Test_ListAllFirewalls_MultiplesResults_WithCache(t *testing.T) {
expected := []*armnetwork.AzureFirewall{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("firewall1"),
},
},
}
fakeClient := &mockFirewallsClient{}
c := &cache.MockCache{}
c.On("Get", "ListAllFirewalls").Return(expected).Times(1)
s := &networkRepository{
firewallsClient: fakeClient,
cache: c,
}
got, err := s.ListAllFirewalls()
if err != nil {
t.Errorf("ListAllFirewalls() error = %v", err)
return
}
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllFirewalls() got = %v, want %v", got, expected)
}
}
func Test_ListAllFirewalls_Error_OnPageResponse(t *testing.T) {
fakeClient := &mockFirewallsClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockFirewallsListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(true).Times(1)
mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
firewallsClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllFirewalls()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllFirewalls_Error(t *testing.T) {
fakeClient := &mockFirewallsClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockFirewallsListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
firewallsClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllFirewalls()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllPublicIPAddresses_MultiplesResults(t *testing.T) {
expected := []*armnetwork.PublicIPAddress{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("ip1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("ip2"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("ip3"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("ip4"),
},
},
}
fakeClient := &mockPublicIPAddressesClient{}
mockPager := &mockPublicIPAddressesListAllPager{}
mockPager.On("Err").Return(nil).Times(3)
mockPager.On("NextPage", mock.Anything).Return(true).Times(2)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{
PublicIPAddressesListAllResult: armnetwork.PublicIPAddressesListAllResult{
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
Value: expected[:2],
},
},
}).Times(1)
mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{
PublicIPAddressesListAllResult: armnetwork.PublicIPAddressesListAllResult{
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
Value: expected[2:],
},
},
}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
c := &cache.MockCache{}
c.On("Get", "ListAllPublicIPAddresses").Return(nil).Times(1)
c.On("Put", "ListAllPublicIPAddresses", expected).Return(true).Times(1)
s := &networkRepository{
publicIPAddressesClient: fakeClient,
cache: c,
}
got, err := s.ListAllPublicIPAddresses()
if err != nil {
t.Errorf("ListAllPublicIPAddresses() error = %v", err)
return
}
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllPublicIPAddresses() got = %v, want %v", got, expected)
}
}
func Test_ListAllPublicIPAddresses_MultiplesResults_WithCache(t *testing.T) {
expected := []*armnetwork.PublicIPAddress{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("ip1"),
},
},
}
fakeClient := &mockPublicIPAddressesClient{}
c := &cache.MockCache{}
c.On("Get", "ListAllPublicIPAddresses").Return(expected).Times(1)
s := &networkRepository{
publicIPAddressesClient: fakeClient,
cache: c,
}
got, err := s.ListAllPublicIPAddresses()
if err != nil {
t.Errorf("ListAllPublicIPAddresses() error = %v", err)
return
}
fakeClient.AssertExpectations(t)
c.AssertExpectations(t)
if !reflect.DeepEqual(got, expected) {
t.Errorf("ListAllPublicIPAddresses() got = %v, want %v", got, expected)
}
}
func Test_ListAllPublicIPAddresses_Error_OnPageResponse(t *testing.T) {
fakeClient := &mockPublicIPAddressesClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockPublicIPAddressesListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(true).Times(1)
mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{}).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
publicIPAddressesClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllPublicIPAddresses()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_ListAllPublicIPAddresses_Error(t *testing.T) {
fakeClient := &mockPublicIPAddressesClient{}
expectedErr := errors.New("unexpected error")
mockPager := &mockPublicIPAddressesListAllPager{}
mockPager.On("Err").Return(expectedErr).Times(1)
mockPager.On("NextPage", mock.Anything).Return(false).Times(1)
fakeClient.On("ListAll", mock.Anything).Return(mockPager)
s := &networkRepository{
publicIPAddressesClient: fakeClient,
cache: cache.New(0),
}
got, err := s.ListAllPublicIPAddresses()
mockPager.AssertExpectations(t)
fakeClient.AssertExpectations(t)
assert.Equal(t, expectedErr, err)
assert.Nil(t, got)
}
func Test_Network_ListAllSecurityGroups(t *testing.T) {
expectedResults := []*armnetwork.NetworkSecurityGroup{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("sgroup-1"),
Name: to.StringPtr("sgroup-1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("sgroup-2"),
Name: to.StringPtr("sgroup-2"),
},
},
}
testcases := []struct {
name string
mocks func(*mockNetworkSecurityGroupsListAllPager, *cache.MockCache)
expected []*armnetwork.NetworkSecurityGroup
wantErr string
}{
{
name: "should return security groups",
mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache.MockCache) {
pager.On("NextPage", context.Background()).Return(true).Times(1)
pager.On("NextPage", context.Background()).Return(false).Times(1)
pager.On("PageResponse").Return(armnetwork.NetworkSecurityGroupsListAllResponse{
NetworkSecurityGroupsListAllResult: armnetwork.NetworkSecurityGroupsListAllResult{
NetworkSecurityGroupListResult: armnetwork.NetworkSecurityGroupListResult{
Value: expectedResults,
},
},
}).Times(1)
pager.On("Err").Return(nil).Times(2)
mockCache.On("Get", "networkListAllSecurityGroups").Return(nil).Times(1)
mockCache.On("Put", "networkListAllSecurityGroups", expectedResults).Return(false).Times(1)
},
expected: expectedResults,
},
{
name: "should hit cache and return security groups",
mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache.MockCache) {
mockCache.On("Get", "networkListAllSecurityGroups").Return(expectedResults).Times(1)
},
expected: expectedResults,
},
{
name: "should return remote error",
mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache.MockCache) {
pager.On("NextPage", context.Background()).Return(true).Times(1)
pager.On("NextPage", context.Background()).Return(false).Times(1)
pager.On("PageResponse").Return(armnetwork.NetworkSecurityGroupsListAllResponse{}).Times(1)
pager.On("Err").Return(errors.New("remote error")).Times(1)
mockCache.On("Get", "networkListAllSecurityGroups").Return(nil).Times(1)
},
wantErr: "remote error",
},
}
for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
fakePager := &mockNetworkSecurityGroupsListAllPager{}
fakeClient := &mockNetworkSecurityGroupsClient{}
mockCache := &cache.MockCache{}
fakeClient.On("ListAll", (*armnetwork.NetworkSecurityGroupsListAllOptions)(nil)).Return(fakePager).Maybe()
tt.mocks(fakePager, mockCache)
s := &networkRepository{
networkSecurityGroupsClient: fakeClient,
cache: mockCache,
}
got, err := s.ListAllSecurityGroups()
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
} else {
assert.Nil(t, err)
}
fakeClient.AssertExpectations(t)
mockCache.AssertExpectations(t)
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("ListAllSecurityGroups() got = %v, want %v", got, tt.expected)
}
})
}
}
func Test_Network_ListAllLoadBalancers(t *testing.T) {
expectedResults := []*armnetwork.LoadBalancer{
{
Resource: armnetwork.Resource{
ID: to.StringPtr("lb-1"),
Name: to.StringPtr("lb-1"),
},
},
{
Resource: armnetwork.Resource{
ID: to.StringPtr("lb-2"),
Name: to.StringPtr("lb-2"),
},
},
}
testcases := []struct {
name string
mocks func(*mockLoadBalancersListAllPager, *cache.MockCache)
expected []*armnetwork.LoadBalancer
wantErr string
}{
{
name: "should return load balancers",
mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache.MockCache) {
pager.On("NextPage", context.Background()).Return(true).Times(1)
pager.On("NextPage", context.Background()).Return(false).Times(1)
pager.On("PageResponse").Return(armnetwork.LoadBalancersListAllResponse{
LoadBalancersListAllResult: armnetwork.LoadBalancersListAllResult{
LoadBalancerListResult: armnetwork.LoadBalancerListResult{
Value: expectedResults,
},
},
}).Times(1)
pager.On("Err").Return(nil).Times(2)
mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(nil).Times(1)
mockCache.On("Put", "networkListAllLoadBalancers", expectedResults).Return(false).Times(1)
mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1)
},
expected: expectedResults,
},
{
name: "should hit cache and return load balancers",
mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache.MockCache) {
mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(expectedResults).Times(1)
mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1)
},
expected: expectedResults,
},
{
name: "should return remote error",
mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache.MockCache) {
pager.On("NextPage", context.Background()).Return(true).Times(1)
pager.On("NextPage", context.Background()).Return(false).Times(1)
pager.On("PageResponse").Return(armnetwork.LoadBalancersListAllResponse{}).Times(1)
pager.On("Err").Return(errors.New("remote error")).Times(1)
mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(nil).Times(1)
mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1)
},
wantErr: "remote error",
},
}
for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
fakePager := &mockLoadBalancersListAllPager{}
fakeClient := &mockLoadBalancersClient{}
mockCache := &cache.MockCache{}
fakeClient.On("ListAll", (*armnetwork.LoadBalancersListAllOptions)(nil)).Return(fakePager).Maybe()
tt.mocks(fakePager, mockCache)
s := &networkRepository{
loadBalancersClient: fakeClient,
cache: mockCache,
}
got, err := s.ListAllLoadBalancers()
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
} else {
assert.Nil(t, err)
}
fakeClient.AssertExpectations(t)
mockCache.AssertExpectations(t)
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("ListAllLoadBalancers() got = %v, want %v", got, tt.expected)
}
})
}
}
func Test_Network_ListLoadBalancerRules(t *testing.T) {
expectedResults := []*armnetwork.LoadBalancingRule{
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("lbrule-1"),
},
Name: to.StringPtr("lbrule-1"),
},
{
SubResource: armnetwork.SubResource{
ID: to.StringPtr("lbrule-1"),
},
Name: to.StringPtr("lbrule-1"),
},
}
testcases := []struct {
name string
loadBalancer *armnetwork.LoadBalancer
mocks func(*mockLoadBalancerRulesClient, *mockLoadBalancerRulesListAllPager, *cache.MockCache)
expected []*armnetwork.LoadBalancingRule
wantErr string
}{
{
name: "should return load balancer rules",
loadBalancer: &armnetwork.LoadBalancer{
Resource: armnetwork.Resource{ID: to.StringPtr("/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress")},
},
mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache.MockCache) {
client.On("List", "driftctl", "PublicIPAddress", &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}).Return(pager)
pager.On("NextPage", context.Background()).Return(true).Times(1)
pager.On("NextPage", context.Background()).Return(false).Times(1)
pager.On("PageResponse").Return(armnetwork.LoadBalancerLoadBalancingRulesListResponse{
LoadBalancerLoadBalancingRulesListResult: armnetwork.LoadBalancerLoadBalancingRulesListResult{
LoadBalancerLoadBalancingRuleListResult: armnetwork.LoadBalancerLoadBalancingRuleListResult{
Value: expectedResults,
},
},
}).Times(1)
pager.On("Err").Return(nil).Times(2)
mockCache.On("Get", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress").Return(nil).Times(1)
mockCache.On("Put", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", expectedResults).Return(false).Times(1)
},
expected: expectedResults,
},
{
name: "should hit cache and return load balancers",
loadBalancer: &armnetwork.LoadBalancer{
Resource: armnetwork.Resource{ID: to.StringPtr("lb-1")},
},
mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache.MockCache) {
mockCache.On("Get", "networkListLoadBalancerRules_lb-1").Return(expectedResults).Times(1)
},
expected: expectedResults,
},
{
name: "should return remote error",
loadBalancer: &armnetwork.LoadBalancer{
Resource: armnetwork.Resource{ID: to.StringPtr("/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress")},
},
mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache.MockCache) {
client.On("List", "driftctl", "PublicIPAddress", &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}).Return(pager)
pager.On("NextPage", context.Background()).Return(true).Times(1)
pager.On("NextPage", context.Background()).Return(false).Times(1)
pager.On("PageResponse").Return(armnetwork.LoadBalancerLoadBalancingRulesListResponse{}).Times(1)
pager.On("Err").Return(errors.New("remote error")).Times(1)
mockCache.On("Get", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress").Return(nil).Times(1)
},
wantErr: "remote error",
},
}
for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
fakePager := &mockLoadBalancerRulesListAllPager{}
fakeClient := &mockLoadBalancerRulesClient{}
mockCache := &cache.MockCache{}
tt.mocks(fakeClient, fakePager, mockCache)
s := &networkRepository{
loadBalancerRulesClient: fakeClient,
cache: mockCache,
}
got, err := s.ListLoadBalancerRules(tt.loadBalancer)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
} else {
assert.Nil(t, err)
}
fakeClient.AssertExpectations(t)
mockCache.AssertExpectations(t)
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("ListAllLoadBalancers() got = %v, want %v", got, tt.expected)
}
})
}
}