From 5880f641c58e3a91d37cffe91ae964be8de404a5 Mon Sep 17 00:00:00 2001 From: Elie Date: Tue, 20 Jul 2021 17:51:51 +0200 Subject: [PATCH] Add allocation_id to nat_gw and eip_assoc --- pkg/remote/aws/ec2_eip_association_enumerator.go | 12 +++++++----- pkg/remote/aws/ec2_nat_gateway_enumerator.go | 10 +++++++++- pkg/remote/aws/repository/ec2_repository.go | 11 ++++++----- pkg/remote/aws/repository/ec2_repository_test.go | 14 +++++++------- pkg/remote/aws/repository/mock_EC2Repository.go | 12 ++++++------ pkg/remote/ec2_scanner_test.go | 9 ++++++--- 6 files changed, 41 insertions(+), 27 deletions(-) diff --git a/pkg/remote/aws/ec2_eip_association_enumerator.go b/pkg/remote/aws/ec2_eip_association_enumerator.go index 6d1beeeb..828d553f 100644 --- a/pkg/remote/aws/ec2_eip_association_enumerator.go +++ b/pkg/remote/aws/ec2_eip_association_enumerator.go @@ -24,20 +24,22 @@ func (e *EC2EipAssociationEnumerator) SupportedType() resource.ResourceType { } func (e *EC2EipAssociationEnumerator) Enumerate() ([]resource.Resource, error) { - associationIds, err := e.repository.ListAllAddressesAssociation() + addresses, err := e.repository.ListAllAddressesAssociation() if err != nil { return nil, remoteerror.NewResourceEnumerationError(err, string(e.SupportedType())) } - results := make([]resource.Resource, len(associationIds)) + results := make([]resource.Resource, 0, len(addresses)) - for _, associationId := range associationIds { + for _, address := range addresses { results = append( results, e.factory.CreateAbstractResource( string(e.SupportedType()), - associationId, - map[string]interface{}{}, + *address.AssociationId, + map[string]interface{}{ + "allocation_id": *address.AllocationId, + }, ), ) } diff --git a/pkg/remote/aws/ec2_nat_gateway_enumerator.go b/pkg/remote/aws/ec2_nat_gateway_enumerator.go index 5539c69b..3200356d 100644 --- a/pkg/remote/aws/ec2_nat_gateway_enumerator.go +++ b/pkg/remote/aws/ec2_nat_gateway_enumerator.go @@ -32,12 +32,20 @@ func (e *EC2NatGatewayEnumerator) Enumerate() ([]resource.Resource, error) { results := make([]resource.Resource, len(natGateways)) for _, natGateway := range natGateways { + + attrs := map[string]interface{}{} + if len(natGateway.NatGatewayAddresses) > 0 { + if allocId := natGateway.NatGatewayAddresses[0].AllocationId; allocId != nil { + attrs["allocation_id"] = *allocId + } + } + results = append( results, e.factory.CreateAbstractResource( string(e.SupportedType()), *natGateway.NatGatewayId, - map[string]interface{}{}, + attrs, ), ) } diff --git a/pkg/remote/aws/repository/ec2_repository.go b/pkg/remote/aws/repository/ec2_repository.go index 15bdef07..4db59755 100644 --- a/pkg/remote/aws/repository/ec2_repository.go +++ b/pkg/remote/aws/repository/ec2_repository.go @@ -13,7 +13,7 @@ type EC2Repository interface { ListAllSnapshots() ([]*ec2.Snapshot, error) ListAllVolumes() ([]*ec2.Volume, error) ListAllAddresses() ([]*ec2.Address, error) - ListAllAddressesAssociation() ([]string, error) + ListAllAddressesAssociation() ([]*ec2.Address, error) ListAllInstances() ([]*ec2.Instance, error) ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) ListAllInternetGateways() ([]*ec2.InternetGateway, error) @@ -108,19 +108,20 @@ func (r *ec2Repository) ListAllAddresses() ([]*ec2.Address, error) { return response.Addresses, nil } -func (r *ec2Repository) ListAllAddressesAssociation() ([]string, error) { +func (r *ec2Repository) ListAllAddressesAssociation() ([]*ec2.Address, error) { if v := r.cache.Get("ec2ListAllAddressesAssociation"); v != nil { - return v.([]string), nil + return v.([]*ec2.Address), nil } - results := make([]string, 0) addresses, err := r.ListAllAddresses() if err != nil { return nil, err } + results := make([]*ec2.Address, 0, len(addresses)) + for _, address := range addresses { if address.AssociationId != nil { - results = append(results, aws.StringValue(address.AssociationId)) + results = append(results, address) } } r.cache.Put("ec2ListAllAddressesAssociation", results) diff --git a/pkg/remote/aws/repository/ec2_repository_test.go b/pkg/remote/aws/repository/ec2_repository_test.go index b2215ff6..7408b89d 100644 --- a/pkg/remote/aws/repository/ec2_repository_test.go +++ b/pkg/remote/aws/repository/ec2_repository_test.go @@ -299,7 +299,7 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { tests := []struct { name string mocks func(client *awstest.MockFakeEC2) - want []string + want []*ec2.Address wantErr error }{ { @@ -315,11 +315,11 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { }, }, nil).Once() }, - want: []string{ - "1", - "2", - "3", - "4", + want: []*ec2.Address{ + {AssociationId: aws.String("1")}, + {AssociationId: aws.String("2")}, + {AssociationId: aws.String("3")}, + {AssociationId: aws.String("4")}, }, }, } @@ -340,7 +340,7 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { cachedData, err := r.ListAllAddressesAssociation() assert.NoError(t, err) assert.Equal(t, got, cachedData) - assert.IsType(t, []string{}, store.Get("ec2ListAllAddressesAssociation")) + assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddressesAssociation")) } changelog, err := diff.Diff(got, tt.want) diff --git a/pkg/remote/aws/repository/mock_EC2Repository.go b/pkg/remote/aws/repository/mock_EC2Repository.go index 258e4e5a..cba89fe1 100644 --- a/pkg/remote/aws/repository/mock_EC2Repository.go +++ b/pkg/remote/aws/repository/mock_EC2Repository.go @@ -1,4 +1,4 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. package repository @@ -7,7 +7,7 @@ import ( mock "github.com/stretchr/testify/mock" ) -// MockEC2Repository is an autogenerated mock type for the MockEC2Repository type +// MockEC2Repository is an autogenerated mock type for the EC2Repository type type MockEC2Repository struct { mock.Mock } @@ -36,15 +36,15 @@ func (_m *MockEC2Repository) ListAllAddresses() ([]*ec2.Address, error) { } // ListAllAddressesAssociation provides a mock function with given fields: -func (_m *MockEC2Repository) ListAllAddressesAssociation() ([]string, error) { +func (_m *MockEC2Repository) ListAllAddressesAssociation() ([]*ec2.Address, error) { ret := _m.Called() - var r0 []string - if rf, ok := ret.Get(0).(func() []string); ok { + var r0 []*ec2.Address + if rf, ok := ret.Get(0).(func() []*ec2.Address); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) + r0 = ret.Get(0).([]*ec2.Address) } } diff --git a/pkg/remote/ec2_scanner_test.go b/pkg/remote/ec2_scanner_test.go index d7c15a40..51b90382 100644 --- a/pkg/remote/ec2_scanner_test.go +++ b/pkg/remote/ec2_scanner_test.go @@ -479,15 +479,18 @@ func TestEC2EipAssociation(t *testing.T) { test: "no eip associations", dirName: "aws_ec2_eip_association_empty", mocks: func(repository *repository.MockEC2Repository) { - repository.On("ListAllAddressesAssociation").Return([]string{}, nil) + repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{}, nil) }, }, { test: "single eip association", dirName: "aws_ec2_eip_association_single", mocks: func(repository *repository.MockEC2Repository) { - repository.On("ListAllAddressesAssociation").Return([]string{ - "eipassoc-0e9a7356e30f0c3d1", + repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{ + { + AssociationId: awssdk.String("eipassoc-0e9a7356e30f0c3d1"), + AllocationId: awssdk.String("eipalloc-017d5267e4dda73f1"), + }, }, nil) }, },