From b4a3458cc329229f27d84db990f2c4267701331f Mon Sep 17 00:00:00 2001 From: Louis TOUSSAINT Date: Thu, 20 May 2021 17:51:36 +0200 Subject: [PATCH] Issue 165: Add vpc_supplier to ec2_repository --- pkg/remote/aws/repository/ec2_repository.go | 23 +++++ .../aws/repository/ec2_repository_test.go | 90 +++++++++++++++++++ .../aws/repository/mock_EC2Repository.go | 32 +++++++ pkg/remote/aws/vpc_supplier.go | 22 +---- pkg/remote/aws/vpc_supplier_test.go | 70 ++++++--------- 5 files changed, 175 insertions(+), 62 deletions(-) diff --git a/pkg/remote/aws/repository/ec2_repository.go b/pkg/remote/aws/repository/ec2_repository.go index 244d8525..0c9a6a3b 100644 --- a/pkg/remote/aws/repository/ec2_repository.go +++ b/pkg/remote/aws/repository/ec2_repository.go @@ -20,6 +20,7 @@ type EC2Repository interface { ListAllSubnets() ([]*ec2.Subnet, []*ec2.Subnet, error) ListAllNatGateways() ([]*ec2.NatGateway, error) ListAllRouteTables() ([]*ec2.RouteTable, error) + ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) } type EC2Client interface { @@ -232,3 +233,25 @@ func (r *ec2Repository) ListAllRouteTables() ([]*ec2.RouteTable, error) { return routeTables, nil } + +func (r *ec2Repository) ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) { + input := ec2.DescribeVpcsInput{} + var VPCs []*ec2.Vpc + var defaultVPCs []*ec2.Vpc + err := r.client.DescribeVpcsPages(&input, + func(resp *ec2.DescribeVpcsOutput, lastPage bool) bool { + for _, vpc := range resp.Vpcs { + if vpc.IsDefault != nil && *vpc.IsDefault { + defaultVPCs = append(defaultVPCs, vpc) + continue + } + VPCs = append(VPCs, vpc) + } + return !lastPage + }, + ) + if err != nil { + return nil, nil, err + } + return VPCs, defaultVPCs, nil +} diff --git a/pkg/remote/aws/repository/ec2_repository_test.go b/pkg/remote/aws/repository/ec2_repository_test.go index ff4bbea1..da34293a 100644 --- a/pkg/remote/aws/repository/ec2_repository_test.go +++ b/pkg/remote/aws/repository/ec2_repository_test.go @@ -946,3 +946,93 @@ func Test_ec2Repository_ListAllRouteTables(t *testing.T) { }) } } + +func Test_ec2Repository_ListAllVPCs(t *testing.T) { + tests := []struct { + name string + mocks func(client *MockEC2Client) + wantVPC []*ec2.Vpc + wantDefaultVPC []*ec2.Vpc + wantErr error + }{ + { + name: "mixed default VPC and VPC", + mocks: func(client *MockEC2Client) { + client.On("DescribeVpcsPages", + &ec2.DescribeVpcsInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeVpcsOutput{ + Vpcs: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-a8c5d4c1"), + IsDefault: aws.Bool(true), + }, + { + VpcId: aws.String("vpc-0768e1fd0029e3fc3"), + }, + { + VpcId: aws.String("vpc-020b072316a95b97f"), + IsDefault: aws.Bool(false), + }, + }, + }, false) + callback(&ec2.DescribeVpcsOutput{ + Vpcs: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-02c50896b59598761"), + IsDefault: aws.Bool(false), + }, + }, + }, true) + return true + })).Return(nil) + }, + wantVPC: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-0768e1fd0029e3fc3"), + }, + { + VpcId: aws.String("vpc-020b072316a95b97f"), + IsDefault: aws.Bool(false), + }, + { + VpcId: aws.String("vpc-02c50896b59598761"), + IsDefault: aws.Bool(false), + }, + }, + wantDefaultVPC: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-a8c5d4c1"), + IsDefault: aws.Bool(true), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &MockEC2Client{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + } + gotVPCs, gotDefaultVPCs, err := r.ListAllVPCs() + assert.Equal(t, tt.wantErr, err) + changelog, err := diff.Diff(gotVPCs, tt.wantVPC) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + changelog, err = diff.Diff(gotDefaultVPCs, tt.wantDefaultVPC) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/pkg/remote/aws/repository/mock_EC2Repository.go b/pkg/remote/aws/repository/mock_EC2Repository.go index c7ff0098..01308af4 100644 --- a/pkg/remote/aws/repository/mock_EC2Repository.go +++ b/pkg/remote/aws/repository/mock_EC2Repository.go @@ -251,6 +251,38 @@ func (_m *MockEC2Repository) ListAllSubnets() ([]*ec2.Subnet, []*ec2.Subnet, err return r0, r1, r2 } +// ListAllVPCs provides a mock function with given fields: +func (_m *MockEC2Repository) ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) { + ret := _m.Called() + + var r0 []*ec2.Vpc + if rf, ok := ret.Get(0).(func() []*ec2.Vpc); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*ec2.Vpc) + } + } + + var r1 []*ec2.Vpc + if rf, ok := ret.Get(1).(func() []*ec2.Vpc); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]*ec2.Vpc) + } + } + + var r2 error + if rf, ok := ret.Get(2).(func() error); ok { + r2 = rf() + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // ListAllVolumes provides a mock function with given fields: func (_m *MockEC2Repository) ListAllVolumes() ([]*ec2.Volume, error) { ret := _m.Called() diff --git a/pkg/remote/aws/vpc_supplier.go b/pkg/remote/aws/vpc_supplier.go index 958a0200..f4a5ae2c 100644 --- a/pkg/remote/aws/vpc_supplier.go +++ b/pkg/remote/aws/vpc_supplier.go @@ -2,7 +2,7 @@ package aws import ( "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/cloudskiff/driftctl/pkg/remote/aws/repository" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" "github.com/cloudskiff/driftctl/pkg/resource/aws" @@ -18,7 +18,7 @@ import ( type VPCSupplier struct { reader terraform.ResourceReader deserializer *resource.Deserializer - client ec2iface.EC2API + client repository.EC2Repository defaultVPCRunner *terraform.ParallelResourceReader vpcRunner *terraform.ParallelResourceReader } @@ -27,28 +27,14 @@ func NewVPCSupplier(provider *AWSTerraformProvider, deserializer *resource.Deser return &VPCSupplier{ provider, deserializer, - ec2.New(provider.session), + repository.NewEC2Repository(provider.session), terraform.NewParallelResourceReader(provider.Runner().SubRunner()), terraform.NewParallelResourceReader(provider.Runner().SubRunner()), } } func (s *VPCSupplier) Resources() ([]resource.Resource, error) { - input := ec2.DescribeVpcsInput{} - var VPCs []*ec2.Vpc - var defaultVPCs []*ec2.Vpc - err := s.client.DescribeVpcsPages(&input, - func(resp *ec2.DescribeVpcsOutput, lastPage bool) bool { - for _, vpc := range resp.Vpcs { - if vpc.IsDefault != nil && *vpc.IsDefault { - defaultVPCs = append(defaultVPCs, vpc) - continue - } - VPCs = append(VPCs, vpc) - } - return !lastPage - }, - ) + VPCs, defaultVPCs, err := s.client.ListAllVPCs() if err != nil { return nil, remoteerror.NewResourceEnumerationError(err, aws.AwsVpcResourceType) diff --git a/pkg/remote/aws/vpc_supplier_test.go b/pkg/remote/aws/vpc_supplier_test.go index 2903fa72..cf011a24 100644 --- a/pkg/remote/aws/vpc_supplier_test.go +++ b/pkg/remote/aws/vpc_supplier_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/cloudskiff/driftctl/pkg/remote/aws/repository" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" awstest "github.com/cloudskiff/driftctl/test/aws" testresource "github.com/cloudskiff/driftctl/test/resource" @@ -31,66 +32,47 @@ func TestVPCSupplier_Resources(t *testing.T) { cases := []struct { test string dirName string - mocks func(client *awstest.MockFakeEC2) + mocks func(client *repository.MockEC2Repository) err error }{ { test: "no VPC", dirName: "vpc_empty", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeVpcsPages", - &ec2.DescribeVpcsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeVpcsOutput{}, true) - return true - })).Return(nil) + mocks: func(client *repository.MockEC2Repository) { + client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{}, []*ec2.Vpc{}, nil) }, err: nil, }, { test: "mixed default VPC and VPC", dirName: "vpc", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeVpcsPages", - &ec2.DescribeVpcsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ - { - VpcId: aws.String("vpc-a8c5d4c1"), - IsDefault: aws.Bool(true), - }, - { - VpcId: aws.String("vpc-0768e1fd0029e3fc3"), - }, - { - VpcId: aws.String("vpc-020b072316a95b97f"), - IsDefault: aws.Bool(false), - }, - }, - }, false) - callback(&ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ - { - VpcId: aws.String("vpc-02c50896b59598761"), - IsDefault: aws.Bool(false), - }, - }, - }, true) - return true - })).Return(nil) + mocks: func(client *repository.MockEC2Repository) { + client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{ + { + VpcId: aws.String("vpc-0768e1fd0029e3fc3"), + }, + { + VpcId: aws.String("vpc-020b072316a95b97f"), + IsDefault: aws.Bool(false), + }, + { + VpcId: aws.String("vpc-02c50896b59598761"), + IsDefault: aws.Bool(false), + }, + }, []*ec2.Vpc{ + { + VpcId: aws.String("vpc-a8c5d4c1"), + IsDefault: aws.Bool(true), + }, + }, nil) }, err: nil, }, { test: "cannot list VPC", dirName: "vpc_empty", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeVpcsPages", - &ec2.DescribeVpcsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool { - return true - })).Return(awserr.NewRequestFailure(nil, 403, "")) + mocks: func(client *repository.MockEC2Repository) { + client.On("ListAllVPCs").Once().Return(nil, nil, awserr.NewRequestFailure(nil, 403, "")) }, err: remoteerror.NewResourceEnumerationError(awserr.NewRequestFailure(nil, 403, ""), resourceaws.AwsVpcResourceType), }, @@ -115,7 +97,7 @@ func TestVPCSupplier_Resources(t *testing.T) { } t.Run(c.test, func(tt *testing.T) { - fakeEC2 := awstest.MockFakeEC2{} + fakeEC2 := repository.MockEC2Repository{} c.mocks(&fakeEC2) provider := mocks2.NewMockedGoldenTFProvider(c.dirName, providerLibrary.Provider(terraform.AWS), shouldUpdate) s := &VPCSupplier{