package repository

import (
	"strings"
	"testing"

	"github.com/cloudskiff/driftctl/pkg/remote/cache"
	awstest "github.com/cloudskiff/driftctl/test/aws"
	"github.com/stretchr/testify/mock"

	"github.com/aws/aws-sdk-go/service/ec2"

	"github.com/aws/aws-sdk-go/aws"

	"github.com/r3labs/diff/v2"
	"github.com/stretchr/testify/assert"
)

// assertEC2ResultsMatch diffs got against want and reports every difference
// as its own test error, formatted as "<path>: <from> -> <to>".
//
// From/To are printed with %v rather than %s because they are untyped values
// that may be nil or non-string, which %s would render as "%!s(...)".
// t.Errorf already marks the test as failed, so no extra t.Fail() is needed.
func assertEC2ResultsMatch(t *testing.T, got, want interface{}) {
	t.Helper()

	changelog, err := diff.Diff(got, want)
	assert.NoError(t, err)
	for _, change := range changelog {
		t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
	}
}

// Test_ec2Repository_ListAllImages checks that ListAllImages returns the
// images of the mocked DescribeImages call (owner "self") and that the result
// is stored in the cache under "ec2ListAllImages". The mocked call is
// registered with Once(), and ListAllImages is invoked twice.
func Test_ec2Repository_ListAllImages(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.Image
		wantErr error
	}{
		{
			name: "List all images",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeImages",
					&ec2.DescribeImagesInput{
						Owners: []*string{
							aws.String("self"),
						},
					}).Return(&ec2.DescribeImagesOutput{
					Images: []*ec2.Image{
						{ImageId: aws.String("1")},
						{ImageId: aws.String("2")},
						{ImageId: aws.String("3")},
						{ImageId: aws.String("4")},
					},
				}, nil).Once()
			},
			want: []*ec2.Image{
				{ImageId: aws.String("1")},
				{ImageId: aws.String("2")},
				{ImageId: aws.String("3")},
				{ImageId: aws.String("4")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllImages()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllImages()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Image{}, store.Get("ec2ListAllImages"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllSnapshots checks that ListAllSnapshots
// concatenates the two pages fed to the DescribeSnapshotsPages callback
// (owner "self") and caches the result under "ec2ListAllSnapshots".
func Test_ec2Repository_ListAllSnapshots(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.Snapshot
		wantErr error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeSnapshotsPages",
					&ec2.DescribeSnapshotsInput{
						OwnerIds: []*string{
							aws.String("self"),
						},
					},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeSnapshotsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeSnapshotsOutput{
							Snapshots: []*ec2.Snapshot{
								{VolumeId: aws.String("1")},
								{VolumeId: aws.String("2")},
								{VolumeId: aws.String("3")},
								{VolumeId: aws.String("4")},
							},
						}, false)
						callback(&ec2.DescribeSnapshotsOutput{
							Snapshots: []*ec2.Snapshot{
								{VolumeId: aws.String("5")},
								{VolumeId: aws.String("6")},
								{VolumeId: aws.String("7")},
								{VolumeId: aws.String("8")},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.Snapshot{
				{VolumeId: aws.String("1")},
				{VolumeId: aws.String("2")},
				{VolumeId: aws.String("3")},
				{VolumeId: aws.String("4")},
				{VolumeId: aws.String("5")},
				{VolumeId: aws.String("6")},
				{VolumeId: aws.String("7")},
				{VolumeId: aws.String("8")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllSnapshots()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllSnapshots()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Snapshot{}, store.Get("ec2ListAllSnapshots"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllVolumes checks that ListAllVolumes concatenates
// the two pages fed to the DescribeVolumesPages callback and caches the
// result under "ec2ListAllVolumes".
func Test_ec2Repository_ListAllVolumes(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.Volume
		wantErr error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeVolumesPages",
					&ec2.DescribeVolumesInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeVolumesOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeVolumesOutput{
							Volumes: []*ec2.Volume{
								{VolumeId: aws.String("1")},
								{VolumeId: aws.String("2")},
								{VolumeId: aws.String("3")},
								{VolumeId: aws.String("4")},
							},
						}, false)
						callback(&ec2.DescribeVolumesOutput{
							Volumes: []*ec2.Volume{
								{VolumeId: aws.String("5")},
								{VolumeId: aws.String("6")},
								{VolumeId: aws.String("7")},
								{VolumeId: aws.String("8")},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.Volume{
				{VolumeId: aws.String("1")},
				{VolumeId: aws.String("2")},
				{VolumeId: aws.String("3")},
				{VolumeId: aws.String("4")},
				{VolumeId: aws.String("5")},
				{VolumeId: aws.String("6")},
				{VolumeId: aws.String("7")},
				{VolumeId: aws.String("8")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllVolumes()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllVolumes()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Volume{}, store.Get("ec2ListAllVolumes"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllAddresses checks that ListAllAddresses returns
// the addresses of the mocked DescribeAddresses call and caches them under
// "ec2ListAllAddresses".
func Test_ec2Repository_ListAllAddresses(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.Address
		wantErr error
	}{
		{
			name: "List address",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}).
					Return(&ec2.DescribeAddressesOutput{
						Addresses: []*ec2.Address{
							{AssociationId: aws.String("1")},
							{AssociationId: aws.String("2")},
							{AssociationId: aws.String("3")},
							{AssociationId: aws.String("4")},
						},
					}, nil).Once()
			},
			want: []*ec2.Address{
				{AssociationId: aws.String("1")},
				{AssociationId: aws.String("2")},
				{AssociationId: aws.String("3")},
				{AssociationId: aws.String("4")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllAddresses()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllAddresses()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddresses"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllAddressesAssociation checks that
// ListAllAddressesAssociation returns the association IDs of the mocked
// DescribeAddresses call as plain strings and caches them under
// "ec2ListAllAddressesAssociation".
func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []string
		wantErr error
	}{
		{
			name: "List address",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}).
					Return(&ec2.DescribeAddressesOutput{
						Addresses: []*ec2.Address{
							{AssociationId: aws.String("1")},
							{AssociationId: aws.String("2")},
							{AssociationId: aws.String("3")},
							{AssociationId: aws.String("4")},
						},
					}, nil).Once()
			},
			want: []string{
				"1",
				"2",
				"3",
				"4",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllAddressesAssociation()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllAddressesAssociation()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []string{}, store.Get("ec2ListAllAddressesAssociation"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllInstances checks that ListAllInstances flattens
// the instances of every reservation across the two pages fed to the
// DescribeInstancesPages callback and caches the result under
// "ec2ListAllInstances".
func Test_ec2Repository_ListAllInstances(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.Instance
		wantErr error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeInstancesPages",
					&ec2.DescribeInstancesInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeInstancesOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeInstancesOutput{
							Reservations: []*ec2.Reservation{
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("1")},
										{ImageId: aws.String("2")},
										{ImageId: aws.String("3")},
									},
								},
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("4")},
										{ImageId: aws.String("5")},
										{ImageId: aws.String("6")},
									},
								},
							},
						}, false)
						callback(&ec2.DescribeInstancesOutput{
							Reservations: []*ec2.Reservation{
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("7")},
										{ImageId: aws.String("8")},
										{ImageId: aws.String("9")},
									},
								},
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("10")},
										{ImageId: aws.String("11")},
										{ImageId: aws.String("12")},
									},
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.Instance{
				{ImageId: aws.String("1")},
				{ImageId: aws.String("2")},
				{ImageId: aws.String("3")},
				{ImageId: aws.String("4")},
				{ImageId: aws.String("5")},
				{ImageId: aws.String("6")},
				{ImageId: aws.String("7")},
				{ImageId: aws.String("8")},
				{ImageId: aws.String("9")},
				{ImageId: aws.String("10")},
				{ImageId: aws.String("11")},
				{ImageId: aws.String("12")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllInstances()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllInstances()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Instance{}, store.Get("ec2ListAllInstances"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllKeyPairs checks that ListAllKeyPairs returns the
// key pairs of the mocked DescribeKeyPairs call and caches them under
// "ec2ListAllKeyPairs".
func Test_ec2Repository_ListAllKeyPairs(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.KeyPairInfo
		wantErr error
	}{
		{
			// Was "List address", a leftover from the addresses test.
			name: "List key pairs",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeKeyPairs", &ec2.DescribeKeyPairsInput{}).
					Return(&ec2.DescribeKeyPairsOutput{
						KeyPairs: []*ec2.KeyPairInfo{
							{KeyPairId: aws.String("1")},
							{KeyPairId: aws.String("2")},
							{KeyPairId: aws.String("3")},
							{KeyPairId: aws.String("4")},
						},
					}, nil).Once()
			},
			want: []*ec2.KeyPairInfo{
				{KeyPairId: aws.String("1")},
				{KeyPairId: aws.String("2")},
				{KeyPairId: aws.String("3")},
				{KeyPairId: aws.String("4")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllKeyPairs()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllKeyPairs()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.KeyPairInfo{}, store.Get("ec2ListAllKeyPairs"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllInternetGateways checks that
// ListAllInternetGateways concatenates the two pages fed to the
// DescribeInternetGatewaysPages callback and caches the result under
// "ec2ListAllInternetGateways".
func Test_ec2Repository_ListAllInternetGateways(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.InternetGateway
		wantErr error
	}{
		{
			name: "List only gateways with multiple pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeInternetGatewaysPages",
					&ec2.DescribeInternetGatewaysInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeInternetGatewaysOutput{
							InternetGateways: []*ec2.InternetGateway{
								{
									InternetGatewayId: aws.String("Internet-0"),
								},
								{
									InternetGatewayId: aws.String("Internet-1"),
								},
							},
						}, false)
						callback(&ec2.DescribeInternetGatewaysOutput{
							InternetGateways: []*ec2.InternetGateway{
								{
									InternetGatewayId: aws.String("Internet-2"),
								},
								{
									InternetGatewayId: aws.String("Internet-3"),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.InternetGateway{
				{
					InternetGatewayId: aws.String("Internet-0"),
				},
				{
					InternetGatewayId: aws.String("Internet-1"),
				},
				{
					InternetGatewayId: aws.String("Internet-2"),
				},
				{
					InternetGatewayId: aws.String("Internet-3"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllInternetGateways()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllInternetGateways()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.InternetGateway{}, store.Get("ec2ListAllInternetGateways"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllSubnets checks that ListAllSubnets returns the
// subnets with DefaultForAz=false and those with DefaultForAz=true as two
// separate slices, and caches them under "ec2ListAllSubnets" and
// "ec2ListAllDefaultSubnets" respectively.
func Test_ec2Repository_ListAllSubnets(t *testing.T) {
	tests := []struct {
		name              string
		mocks             func(client *awstest.MockFakeEC2)
		wantSubnet        []*ec2.Subnet
		wantDefaultSubnet []*ec2.Subnet
		wantErr           error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeSubnetsPages",
					&ec2.DescribeSubnetsInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeSubnetsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeSubnetsOutput{
							Subnets: []*ec2.Subnet{
								{
									SubnetId:     aws.String("subnet-0b13f1e0eacf67424"), // subnet2
									DefaultForAz: aws.Bool(false),
								},
								{
									SubnetId:     aws.String("subnet-0c9b78001fe186e22"), // subnet3
									DefaultForAz: aws.Bool(false),
								},
								{
									SubnetId:     aws.String("subnet-05810d3f933925f6d"), // subnet1
									DefaultForAz: aws.Bool(false),
								},
							},
						}, false)
						callback(&ec2.DescribeSubnetsOutput{
							Subnets: []*ec2.Subnet{
								{
									SubnetId:     aws.String("subnet-44fe0c65"), // us-east-1a
									DefaultForAz: aws.Bool(true),
								},
								{
									SubnetId:     aws.String("subnet-65e16628"), // us-east-1b
									DefaultForAz: aws.Bool(true),
								},
								{
									SubnetId:     aws.String("subnet-afa656f0"), // us-east-1c
									DefaultForAz: aws.Bool(true),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			wantSubnet: []*ec2.Subnet{
				{
					SubnetId:     aws.String("subnet-0b13f1e0eacf67424"), // subnet2
					DefaultForAz: aws.Bool(false),
				},
				{
					SubnetId:     aws.String("subnet-0c9b78001fe186e22"), // subnet3
					DefaultForAz: aws.Bool(false),
				},
				{
					SubnetId:     aws.String("subnet-05810d3f933925f6d"), // subnet1
					DefaultForAz: aws.Bool(false),
				},
			},
			wantDefaultSubnet: []*ec2.Subnet{
				{
					SubnetId:     aws.String("subnet-44fe0c65"), // us-east-1a
					DefaultForAz: aws.Bool(true),
				},
				{
					SubnetId:     aws.String("subnet-65e16628"), // us-east-1b
					DefaultForAz: aws.Bool(true),
				},
				{
					SubnetId:     aws.String("subnet-afa656f0"), // us-east-1c
					DefaultForAz: aws.Bool(true),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(2)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			gotSubnet, gotDefaultSubnet, err := r.ListAllSubnets()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, cachedDefaultData, err := r.ListAllSubnets()
				assert.NoError(t, err)
				assert.Equal(t, gotSubnet, cachedData)
				assert.Equal(t, gotDefaultSubnet, cachedDefaultData)
				assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllSubnets"))
				assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllDefaultSubnets"))
			}

			assertEC2ResultsMatch(t, gotSubnet, tt.wantSubnet)
			assertEC2ResultsMatch(t, gotDefaultSubnet, tt.wantDefaultSubnet)
		})
	}
}

// Test_ec2Repository_ListAllNatGateways checks that ListAllNatGateways
// concatenates the two pages fed to the DescribeNatGatewaysPages callback and
// caches the result under "ec2ListAllNatGateways".
func Test_ec2Repository_ListAllNatGateways(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.NatGateway
		wantErr error
	}{
		{
			name: "List only gateways with multiple pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeNatGatewaysPages",
					&ec2.DescribeNatGatewaysInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeNatGatewaysOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeNatGatewaysOutput{
							NatGateways: []*ec2.NatGateway{
								{
									NatGatewayId: aws.String("nat-0"),
								},
								{
									NatGatewayId: aws.String("nat-1"),
								},
							},
						}, false)
						callback(&ec2.DescribeNatGatewaysOutput{
							NatGateways: []*ec2.NatGateway{
								{
									NatGatewayId: aws.String("nat-2"),
								},
								{
									NatGatewayId: aws.String("nat-3"),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.NatGateway{
				{
					NatGatewayId: aws.String("nat-0"),
				},
				{
					NatGatewayId: aws.String("nat-1"),
				},
				{
					NatGatewayId: aws.String("nat-2"),
				},
				{
					NatGatewayId: aws.String("nat-3"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllNatGateways()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllNatGateways()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.NatGateway{}, store.Get("ec2ListAllNatGateways"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllRouteTables checks that ListAllRouteTables
// returns every route table from the two pages fed to the
// DescribeRouteTablesPages callback, unchanged and in order, and caches the
// result under "ec2ListAllRouteTables".
func Test_ec2Repository_ListAllRouteTables(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *awstest.MockFakeEC2)
		want    []*ec2.RouteTable
		wantErr error
	}{
		{
			name: "List only route with multiple pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeRouteTablesPages",
					&ec2.DescribeRouteTablesInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeRouteTablesOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeRouteTablesOutput{
							RouteTables: []*ec2.RouteTable{
								{
									RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
										{
											DestinationCidrBlock: aws.String("1.1.1.1/32"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
										{
											DestinationIpv6CidrBlock: aws.String("::/0"),
											GatewayId:                aws.String("igw-030e74f73bd67f21b"),
										},
									},
								},
								{
									RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
										{
											DestinationCidrBlock: aws.String("0.0.0.0/0"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
										{
											DestinationIpv6CidrBlock: aws.String("::/0"),
											GatewayId:                aws.String("igw-030e74f73bd67f21b"),
										},
									},
								},
							},
						}, false)
						callback(&ec2.DescribeRouteTablesOutput{
							RouteTables: []*ec2.RouteTable{
								{
									RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table
									VpcId:        aws.String("vpc-09fe5abc2309ba49d"),
									Associations: []*ec2.RouteTableAssociation{
										{
											Main: aws.Bool(true),
										},
									},
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
										{
											DestinationCidrBlock: aws.String("10.1.1.0/24"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
										{
											DestinationCidrBlock: aws.String("10.1.2.0/24"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
									},
								},
								{
									RouteTableId: aws.String(""), // table3
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
									},
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.RouteTable{
				{
					RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
						{
							DestinationCidrBlock: aws.String("1.1.1.1/32"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
						{
							DestinationIpv6CidrBlock: aws.String("::/0"),
							GatewayId:                aws.String("igw-030e74f73bd67f21b"),
						},
					},
				},
				{
					RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
						{
							DestinationCidrBlock: aws.String("0.0.0.0/0"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
						{
							DestinationIpv6CidrBlock: aws.String("::/0"),
							GatewayId:                aws.String("igw-030e74f73bd67f21b"),
						},
					},
				},
				{
					RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table
					VpcId:        aws.String("vpc-09fe5abc2309ba49d"),
					Associations: []*ec2.RouteTableAssociation{
						{
							Main: aws.Bool(true),
						},
					},
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
						{
							DestinationCidrBlock: aws.String("10.1.1.0/24"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
						{
							DestinationCidrBlock: aws.String("10.1.2.0/24"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
					},
				},
				{
					RouteTableId: aws.String(""), // table3
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllRouteTables()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, err := r.ListAllRouteTables()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.RouteTable{}, store.Get("ec2ListAllRouteTables"))
			}

			assertEC2ResultsMatch(t, got, tt.want)
		})
	}
}

// Test_ec2Repository_ListAllVPCs checks that ListAllVPCs returns the VPC
// flagged IsDefault=true separately from the others (including the one with
// no IsDefault value), and caches the two slices under "ec2ListAllVPCs" and
// "ec2ListAllDefaultVPCs".
func Test_ec2Repository_ListAllVPCs(t *testing.T) {
	tests := []struct {
		name           string
		mocks          func(client *awstest.MockFakeEC2)
		wantVPC        []*ec2.Vpc
		wantDefaultVPC []*ec2.Vpc
		wantErr        error
	}{
		{
			name: "mixed default VPC and VPC",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeVpcsPages",
					&ec2.DescribeVpcsInput{},
					// The matcher doubles as the paginator: it feeds both pages
					// to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeVpcsOutput{
							Vpcs: []*ec2.Vpc{
								{
									VpcId:     aws.String("vpc-a8c5d4c1"),
									IsDefault: aws.Bool(true),
								},
								{
									VpcId: aws.String("vpc-0768e1fd0029e3fc3"),
								},
								{
									VpcId:     aws.String("vpc-020b072316a95b97f"),
									IsDefault: aws.Bool(false),
								},
							},
						}, false)
						callback(&ec2.DescribeVpcsOutput{
							Vpcs: []*ec2.Vpc{
								{
									VpcId:     aws.String("vpc-02c50896b59598761"),
									IsDefault: aws.Bool(false),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			wantVPC: []*ec2.Vpc{
				{
					VpcId: aws.String("vpc-0768e1fd0029e3fc3"),
				},
				{
					VpcId:     aws.String("vpc-020b072316a95b97f"),
					IsDefault: aws.Bool(false),
				},
				{
					VpcId:     aws.String("vpc-02c50896b59598761"),
					IsDefault: aws.Bool(false),
				},
			},
			wantDefaultVPC: []*ec2.Vpc{
				{
					VpcId:     aws.String("vpc-a8c5d4c1"),
					IsDefault: aws.Bool(true),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(2)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			gotVPCs, gotDefaultVPCs, err := r.ListAllVPCs()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, cachedDefaultData, err := r.ListAllVPCs()
				assert.NoError(t, err)
				assert.Equal(t, gotVPCs, cachedData)
				assert.Equal(t, gotDefaultVPCs, cachedDefaultData)
				assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllVPCs"))
				assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllDefaultVPCs"))
			}

			assertEC2ResultsMatch(t, gotVPCs, tt.wantVPC)
			assertEC2ResultsMatch(t, gotDefaultVPCs, tt.wantDefaultVPC)
		})
	}
}

// Test_ec2Repository_ListAllSecurityGroups checks that ListAllSecurityGroups
// returns the group named "default" separately from the other group, and
// caches the two slices under "ec2ListAllSecurityGroups" and
// "ec2ListAllDefaultSecurityGroups".
func Test_ec2Repository_ListAllSecurityGroups(t *testing.T) {
	tests := []struct {
		name                     string
		mocks                    func(client *awstest.MockFakeEC2)
		wantSecurityGroup        []*ec2.SecurityGroup
		wantDefaultSecurityGroup []*ec2.SecurityGroup
		wantErr                  error
	}{
		{
			name: "List with 1 pages",
			mocks: func(client *awstest.MockFakeEC2) {
				client.On("DescribeSecurityGroupsPages",
					&ec2.DescribeSecurityGroupsInput{},
					// The matcher doubles as the paginator: it feeds the single
					// page to the callback supplied by the repository.
					mock.MatchedBy(func(callback func(res *ec2.DescribeSecurityGroupsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeSecurityGroupsOutput{
							SecurityGroups: []*ec2.SecurityGroup{
								{
									GroupId:   aws.String("sg-0254c038e32f25530"),
									GroupName: aws.String("foo"),
								},
								{
									GroupId:   aws.String("sg-9e0204ff"),
									GroupName: aws.String("default"),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			wantSecurityGroup: []*ec2.SecurityGroup{
				{
					GroupId:   aws.String("sg-0254c038e32f25530"),
					GroupName: aws.String("foo"),
				},
			},
			wantDefaultSecurityGroup: []*ec2.SecurityGroup{
				{
					GroupId:   aws.String("sg-9e0204ff"),
					GroupName: aws.String("default"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(2)
			client := &awstest.MockFakeEC2{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			gotSecurityGroups, gotDefaultSecurityGroups, err := r.ListAllSecurityGroups()
			assert.Equal(t, tt.wantErr, err)

			if err == nil {
				// Check that results were cached
				cachedData, cachedDefaultData, err := r.ListAllSecurityGroups()
				assert.NoError(t, err)
				assert.Equal(t, gotSecurityGroups, cachedData)
				assert.Equal(t, gotDefaultSecurityGroups, cachedDefaultData)
				assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllSecurityGroups"))
				assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllDefaultSecurityGroups"))
			}

			assertEC2ResultsMatch(t, gotSecurityGroups, tt.wantSecurityGroup)
			assertEC2ResultsMatch(t, gotDefaultSecurityGroups, tt.wantDefaultSecurityGroup)
		})
	}
}