Add allocation_id to nat_gw and eip_assoc

main
Elie 2021-07-20 17:51:51 +02:00
parent 2ac36fd416
commit 5880f641c5
No known key found for this signature in database
GPG Key ID: 399AF69092C727B6
6 changed files with 41 additions and 27 deletions

View File

@ -24,20 +24,22 @@ func (e *EC2EipAssociationEnumerator) SupportedType() resource.ResourceType {
} }
func (e *EC2EipAssociationEnumerator) Enumerate() ([]resource.Resource, error) { func (e *EC2EipAssociationEnumerator) Enumerate() ([]resource.Resource, error) {
associationIds, err := e.repository.ListAllAddressesAssociation() addresses, err := e.repository.ListAllAddressesAssociation()
if err != nil { if err != nil {
return nil, remoteerror.NewResourceEnumerationError(err, string(e.SupportedType())) return nil, remoteerror.NewResourceEnumerationError(err, string(e.SupportedType()))
} }
results := make([]resource.Resource, len(associationIds)) results := make([]resource.Resource, 0, len(addresses))
for _, associationId := range associationIds { for _, address := range addresses {
results = append( results = append(
results, results,
e.factory.CreateAbstractResource( e.factory.CreateAbstractResource(
string(e.SupportedType()), string(e.SupportedType()),
associationId, *address.AssociationId,
map[string]interface{}{}, map[string]interface{}{
"allocation_id": *address.AllocationId,
},
), ),
) )
} }

View File

@ -32,12 +32,20 @@ func (e *EC2NatGatewayEnumerator) Enumerate() ([]resource.Resource, error) {
results := make([]resource.Resource, len(natGateways)) results := make([]resource.Resource, len(natGateways))
for _, natGateway := range natGateways { for _, natGateway := range natGateways {
attrs := map[string]interface{}{}
if len(natGateway.NatGatewayAddresses) > 0 {
if allocId := natGateway.NatGatewayAddresses[0].AllocationId; allocId != nil {
attrs["allocation_id"] = *allocId
}
}
results = append( results = append(
results, results,
e.factory.CreateAbstractResource( e.factory.CreateAbstractResource(
string(e.SupportedType()), string(e.SupportedType()),
*natGateway.NatGatewayId, *natGateway.NatGatewayId,
map[string]interface{}{}, attrs,
), ),
) )
} }

View File

@ -13,7 +13,7 @@ type EC2Repository interface {
ListAllSnapshots() ([]*ec2.Snapshot, error) ListAllSnapshots() ([]*ec2.Snapshot, error)
ListAllVolumes() ([]*ec2.Volume, error) ListAllVolumes() ([]*ec2.Volume, error)
ListAllAddresses() ([]*ec2.Address, error) ListAllAddresses() ([]*ec2.Address, error)
ListAllAddressesAssociation() ([]string, error) ListAllAddressesAssociation() ([]*ec2.Address, error)
ListAllInstances() ([]*ec2.Instance, error) ListAllInstances() ([]*ec2.Instance, error)
ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) ListAllKeyPairs() ([]*ec2.KeyPairInfo, error)
ListAllInternetGateways() ([]*ec2.InternetGateway, error) ListAllInternetGateways() ([]*ec2.InternetGateway, error)
@ -108,19 +108,20 @@ func (r *ec2Repository) ListAllAddresses() ([]*ec2.Address, error) {
return response.Addresses, nil return response.Addresses, nil
} }
func (r *ec2Repository) ListAllAddressesAssociation() ([]string, error) { func (r *ec2Repository) ListAllAddressesAssociation() ([]*ec2.Address, error) {
if v := r.cache.Get("ec2ListAllAddressesAssociation"); v != nil { if v := r.cache.Get("ec2ListAllAddressesAssociation"); v != nil {
return v.([]string), nil return v.([]*ec2.Address), nil
} }
results := make([]string, 0)
addresses, err := r.ListAllAddresses() addresses, err := r.ListAllAddresses()
if err != nil { if err != nil {
return nil, err return nil, err
} }
results := make([]*ec2.Address, 0, len(addresses))
for _, address := range addresses { for _, address := range addresses {
if address.AssociationId != nil { if address.AssociationId != nil {
results = append(results, aws.StringValue(address.AssociationId)) results = append(results, address)
} }
} }
r.cache.Put("ec2ListAllAddressesAssociation", results) r.cache.Put("ec2ListAllAddressesAssociation", results)

View File

@ -299,7 +299,7 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mocks func(client *awstest.MockFakeEC2) mocks func(client *awstest.MockFakeEC2)
want []string want []*ec2.Address
wantErr error wantErr error
}{ }{
{ {
@ -315,11 +315,11 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) {
}, },
}, nil).Once() }, nil).Once()
}, },
want: []string{ want: []*ec2.Address{
"1", {AssociationId: aws.String("1")},
"2", {AssociationId: aws.String("2")},
"3", {AssociationId: aws.String("3")},
"4", {AssociationId: aws.String("4")},
}, },
}, },
} }
@ -340,7 +340,7 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) {
cachedData, err := r.ListAllAddressesAssociation() cachedData, err := r.ListAllAddressesAssociation()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, got, cachedData) assert.Equal(t, got, cachedData)
assert.IsType(t, []string{}, store.Get("ec2ListAllAddressesAssociation")) assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddressesAssociation"))
} }
changelog, err := diff.Diff(got, tt.want) changelog, err := diff.Diff(got, tt.want)

View File

@ -1,4 +1,4 @@
// Code generated by mockery v1.0.0. DO NOT EDIT. // Code generated by mockery v0.0.0-dev. DO NOT EDIT.
package repository package repository
@ -7,7 +7,7 @@ import (
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
) )
// MockEC2Repository is an autogenerated mock type for the MockEC2Repository type // MockEC2Repository is an autogenerated mock type for the EC2Repository type
type MockEC2Repository struct { type MockEC2Repository struct {
mock.Mock mock.Mock
} }
@ -36,15 +36,15 @@ func (_m *MockEC2Repository) ListAllAddresses() ([]*ec2.Address, error) {
} }
// ListAllAddressesAssociation provides a mock function with given fields: // ListAllAddressesAssociation provides a mock function with given fields:
func (_m *MockEC2Repository) ListAllAddressesAssociation() ([]string, error) { func (_m *MockEC2Repository) ListAllAddressesAssociation() ([]*ec2.Address, error) {
ret := _m.Called() ret := _m.Called()
var r0 []string var r0 []*ec2.Address
if rf, ok := ret.Get(0).(func() []string); ok { if rf, ok := ret.Get(0).(func() []*ec2.Address); ok {
r0 = rf() r0 = rf()
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).([]string) r0 = ret.Get(0).([]*ec2.Address)
} }
} }

View File

@ -479,15 +479,18 @@ func TestEC2EipAssociation(t *testing.T) {
test: "no eip associations", test: "no eip associations",
dirName: "aws_ec2_eip_association_empty", dirName: "aws_ec2_eip_association_empty",
mocks: func(repository *repository.MockEC2Repository) { mocks: func(repository *repository.MockEC2Repository) {
repository.On("ListAllAddressesAssociation").Return([]string{}, nil) repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{}, nil)
}, },
}, },
{ {
test: "single eip association", test: "single eip association",
dirName: "aws_ec2_eip_association_single", dirName: "aws_ec2_eip_association_single",
mocks: func(repository *repository.MockEC2Repository) { mocks: func(repository *repository.MockEC2Repository) {
repository.On("ListAllAddressesAssociation").Return([]string{ repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{
"eipassoc-0e9a7356e30f0c3d1", {
AssociationId: awssdk.String("eipassoc-0e9a7356e30f0c3d1"),
AllocationId: awssdk.String("eipalloc-017d5267e4dda73f1"),
},
}, nil) }, nil)
}, },
}, },