driftctl/pkg/remote/aws/repository/ec2_repository_test.go

1183 lines
32 KiB
Go

package repository
import (
"strings"
"testing"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
"github.com/stretchr/testify/mock"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/aws"
"github.com/r3labs/diff/v2"
"github.com/stretchr/testify/assert"
)
// Test_ec2Repository_ListAllImages checks that ListAllImages describes the
// images owned by "self", returns them in order, and stores the result in the
// cache under the "ec2ListAllImages" key.
func Test_ec2Repository_ListAllImages(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.Image
		wantErr error
	}{
		{
			name: "List all images",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeImages",
					&ec2.DescribeImagesInput{
						Owners: []*string{
							aws.String("self"),
						},
					}).Return(&ec2.DescribeImagesOutput{
					Images: []*ec2.Image{
						{ImageId: aws.String("1")},
						{ImageId: aws.String("2")},
						{ImageId: aws.String("3")},
						{ImageId: aws.String("4")},
					},
				}, nil).Once()
			},
			want: []*ec2.Image{
				{ImageId: aws.String("1")},
				{ImageId: aws.String("2")},
				{ImageId: aws.String("3")},
				{ImageId: aws.String("4")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllImages()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the DescribeImages expectation is registered
				// with Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllImages()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Image{}, store.Get("ec2ListAllImages"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllSnapshots checks that ListAllSnapshots pages
// through the snapshots owned by "self", concatenates both pages in order,
// and stores the result in the cache under the "ec2ListAllSnapshots" key.
func Test_ec2Repository_ListAllSnapshots(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.Snapshot
		wantErr error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeSnapshotsPages",
					&ec2.DescribeSnapshotsInput{
						OwnerIds: []*string{
							aws.String("self"),
						},
					},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeSnapshotsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeSnapshotsOutput{
							Snapshots: []*ec2.Snapshot{
								{VolumeId: aws.String("1")},
								{VolumeId: aws.String("2")},
								{VolumeId: aws.String("3")},
								{VolumeId: aws.String("4")},
							},
						}, false)
						callback(&ec2.DescribeSnapshotsOutput{
							Snapshots: []*ec2.Snapshot{
								{VolumeId: aws.String("5")},
								{VolumeId: aws.String("6")},
								{VolumeId: aws.String("7")},
								{VolumeId: aws.String("8")},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.Snapshot{
				{VolumeId: aws.String("1")},
				{VolumeId: aws.String("2")},
				{VolumeId: aws.String("3")},
				{VolumeId: aws.String("4")},
				{VolumeId: aws.String("5")},
				{VolumeId: aws.String("6")},
				{VolumeId: aws.String("7")},
				{VolumeId: aws.String("8")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllSnapshots()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllSnapshots()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Snapshot{}, store.Get("ec2ListAllSnapshots"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllVolumes checks that ListAllVolumes pages through
// DescribeVolumesPages with an empty input, concatenates both pages in order,
// and stores the result in the cache under the "ec2ListAllVolumes" key.
func Test_ec2Repository_ListAllVolumes(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.Volume
		wantErr error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeVolumesPages",
					&ec2.DescribeVolumesInput{},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeVolumesOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeVolumesOutput{
							Volumes: []*ec2.Volume{
								{VolumeId: aws.String("1")},
								{VolumeId: aws.String("2")},
								{VolumeId: aws.String("3")},
								{VolumeId: aws.String("4")},
							},
						}, false)
						callback(&ec2.DescribeVolumesOutput{
							Volumes: []*ec2.Volume{
								{VolumeId: aws.String("5")},
								{VolumeId: aws.String("6")},
								{VolumeId: aws.String("7")},
								{VolumeId: aws.String("8")},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.Volume{
				{VolumeId: aws.String("1")},
				{VolumeId: aws.String("2")},
				{VolumeId: aws.String("3")},
				{VolumeId: aws.String("4")},
				{VolumeId: aws.String("5")},
				{VolumeId: aws.String("6")},
				{VolumeId: aws.String("7")},
				{VolumeId: aws.String("8")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllVolumes()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllVolumes()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Volume{}, store.Get("ec2ListAllVolumes"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllAddresses checks that ListAllAddresses returns the
// addresses from a single DescribeAddresses call in order, and stores the
// result in the cache under the "ec2ListAllAddresses" key.
func Test_ec2Repository_ListAllAddresses(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.Address
		wantErr error
	}{
		{
			name: "List address",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}).
					Return(&ec2.DescribeAddressesOutput{
						Addresses: []*ec2.Address{
							{AssociationId: aws.String("1")},
							{AssociationId: aws.String("2")},
							{AssociationId: aws.String("3")},
							{AssociationId: aws.String("4")},
						},
					}, nil).Once()
			},
			want: []*ec2.Address{
				{AssociationId: aws.String("1")},
				{AssociationId: aws.String("2")},
				{AssociationId: aws.String("3")},
				{AssociationId: aws.String("4")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllAddresses()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the DescribeAddresses expectation is registered
				// with Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllAddresses()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddresses"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllAddressesAssociation checks that
// ListAllAddressesAssociation returns the AssociationId of each address from
// DescribeAddresses, in order, as plain strings, and stores the result in the
// cache under the "ec2ListAllAddressesAssociation" key.
func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []string
		wantErr error
	}{
		{
			name: "List address associations",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}).
					Return(&ec2.DescribeAddressesOutput{
						Addresses: []*ec2.Address{
							{AssociationId: aws.String("1")},
							{AssociationId: aws.String("2")},
							{AssociationId: aws.String("3")},
							{AssociationId: aws.String("4")},
						},
					}, nil).Once()
			},
			want: []string{
				"1",
				"2",
				"3",
				"4",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllAddressesAssociation()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the DescribeAddresses expectation is registered
				// with Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllAddressesAssociation()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []string{}, store.Get("ec2ListAllAddressesAssociation"))
			}
			// Report each element-level difference on its own; %v keeps the
			// output readable when a side is missing (nil).
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllInstances checks that ListAllInstances flattens
// the instances of every reservation across both DescribeInstancesPages pages
// (preserving order), and stores the result in the cache under the
// "ec2ListAllInstances" key.
func Test_ec2Repository_ListAllInstances(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.Instance
		wantErr error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeInstancesPages",
					&ec2.DescribeInstancesInput{},
					// The matcher feeds two pages (two reservations each) to
					// the repository's pagination callback, then accepts it.
					mock.MatchedBy(func(callback func(res *ec2.DescribeInstancesOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeInstancesOutput{
							Reservations: []*ec2.Reservation{
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("1")},
										{ImageId: aws.String("2")},
										{ImageId: aws.String("3")},
									},
								},
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("4")},
										{ImageId: aws.String("5")},
										{ImageId: aws.String("6")},
									},
								},
							},
						}, false)
						callback(&ec2.DescribeInstancesOutput{
							Reservations: []*ec2.Reservation{
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("7")},
										{ImageId: aws.String("8")},
										{ImageId: aws.String("9")},
									},
								},
								{
									Instances: []*ec2.Instance{
										{ImageId: aws.String("10")},
										{ImageId: aws.String("11")},
										{ImageId: aws.String("12")},
									},
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.Instance{
				{ImageId: aws.String("1")},
				{ImageId: aws.String("2")},
				{ImageId: aws.String("3")},
				{ImageId: aws.String("4")},
				{ImageId: aws.String("5")},
				{ImageId: aws.String("6")},
				{ImageId: aws.String("7")},
				{ImageId: aws.String("8")},
				{ImageId: aws.String("9")},
				{ImageId: aws.String("10")},
				{ImageId: aws.String("11")},
				{ImageId: aws.String("12")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllInstances()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllInstances()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.Instance{}, store.Get("ec2ListAllInstances"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllKeyPairs checks that ListAllKeyPairs returns the
// key pairs from a single DescribeKeyPairs call in order, and stores the
// result in the cache under the "ec2ListAllKeyPairs" key.
func Test_ec2Repository_ListAllKeyPairs(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.KeyPairInfo
		wantErr error
	}{
		{
			name: "List key pairs",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeKeyPairs", &ec2.DescribeKeyPairsInput{}).
					Return(&ec2.DescribeKeyPairsOutput{
						KeyPairs: []*ec2.KeyPairInfo{
							{KeyPairId: aws.String("1")},
							{KeyPairId: aws.String("2")},
							{KeyPairId: aws.String("3")},
							{KeyPairId: aws.String("4")},
						},
					}, nil).Once()
			},
			want: []*ec2.KeyPairInfo{
				{KeyPairId: aws.String("1")},
				{KeyPairId: aws.String("2")},
				{KeyPairId: aws.String("3")},
				{KeyPairId: aws.String("4")},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllKeyPairs()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the DescribeKeyPairs expectation is registered
				// with Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllKeyPairs()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.KeyPairInfo{}, store.Get("ec2ListAllKeyPairs"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllInternetGateways checks that
// ListAllInternetGateways concatenates the gateways from both
// DescribeInternetGatewaysPages pages in order, and stores the result in the
// cache under the "ec2ListAllInternetGateways" key.
func Test_ec2Repository_ListAllInternetGateways(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.InternetGateway
		wantErr error
	}{
		{
			name: "List only gateways with multiple pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeInternetGatewaysPages",
					&ec2.DescribeInternetGatewaysInput{},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeInternetGatewaysOutput{
							InternetGateways: []*ec2.InternetGateway{
								{
									InternetGatewayId: aws.String("Internet-0"),
								},
								{
									InternetGatewayId: aws.String("Internet-1"),
								},
							},
						}, false)
						callback(&ec2.DescribeInternetGatewaysOutput{
							InternetGateways: []*ec2.InternetGateway{
								{
									InternetGatewayId: aws.String("Internet-2"),
								},
								{
									InternetGatewayId: aws.String("Internet-3"),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.InternetGateway{
				{
					InternetGatewayId: aws.String("Internet-0"),
				},
				{
					InternetGatewayId: aws.String("Internet-1"),
				},
				{
					InternetGatewayId: aws.String("Internet-2"),
				},
				{
					InternetGatewayId: aws.String("Internet-3"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllInternetGateways()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllInternetGateways()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.InternetGateway{}, store.Get("ec2ListAllInternetGateways"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllSubnets checks that ListAllSubnets splits the
// subnets from both DescribeSubnetsPages pages into non-default subnets
// (DefaultForAz false) and default subnets (DefaultForAz true), preserving
// order, and caches them under "ec2ListAllSubnets" and
// "ec2ListAllDefaultSubnets" respectively.
func Test_ec2Repository_ListAllSubnets(t *testing.T) {
	tests := []struct {
		name              string
		mocks             func(client *MockEC2Client)
		wantSubnet        []*ec2.Subnet
		wantDefaultSubnet []*ec2.Subnet
		wantErr           error
	}{
		{
			name: "List with 2 pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeSubnetsPages",
					&ec2.DescribeSubnetsInput{},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeSubnetsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeSubnetsOutput{
							Subnets: []*ec2.Subnet{
								{
									SubnetId:     aws.String("subnet-0b13f1e0eacf67424"), // subnet2
									DefaultForAz: aws.Bool(false),
								},
								{
									SubnetId:     aws.String("subnet-0c9b78001fe186e22"), // subnet3
									DefaultForAz: aws.Bool(false),
								},
								{
									SubnetId:     aws.String("subnet-05810d3f933925f6d"), // subnet1
									DefaultForAz: aws.Bool(false),
								},
							},
						}, false)
						callback(&ec2.DescribeSubnetsOutput{
							Subnets: []*ec2.Subnet{
								{
									SubnetId:     aws.String("subnet-44fe0c65"), // us-east-1a
									DefaultForAz: aws.Bool(true),
								},
								{
									SubnetId:     aws.String("subnet-65e16628"), // us-east-1b
									DefaultForAz: aws.Bool(true),
								},
								{
									SubnetId:     aws.String("subnet-afa656f0"), // us-east-1c
									DefaultForAz: aws.Bool(true),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			wantSubnet: []*ec2.Subnet{
				{
					SubnetId:     aws.String("subnet-0b13f1e0eacf67424"), // subnet2
					DefaultForAz: aws.Bool(false),
				},
				{
					SubnetId:     aws.String("subnet-0c9b78001fe186e22"), // subnet3
					DefaultForAz: aws.Bool(false),
				},
				{
					SubnetId:     aws.String("subnet-05810d3f933925f6d"), // subnet1
					DefaultForAz: aws.Bool(false),
				},
			},
			wantDefaultSubnet: []*ec2.Subnet{
				{
					SubnetId:     aws.String("subnet-44fe0c65"), // us-east-1a
					DefaultForAz: aws.Bool(true),
				},
				{
					SubnetId:     aws.String("subnet-65e16628"), // us-east-1b
					DefaultForAz: aws.Bool(true),
				},
				{
					SubnetId:     aws.String("subnet-afa656f0"), // us-east-1c
					DefaultForAz: aws.Bool(true),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Report each field-level difference between got and want as a
			// separate test error; %v is used because change values are not
			// necessarily strings.
			assertNoDiff := func(got, want interface{}) {
				t.Helper()
				changelog, err := diff.Diff(got, want)
				assert.NoError(t, err)
				for _, change := range changelog {
					t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
				}
			}

			// Two cache entries are written: regular and default subnets.
			store := cache.New(2)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			gotSubnet, gotDefaultSubnet, err := r.ListAllSubnets()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so both results have to come from the cache.
				cachedData, cachedDefaultData, err := r.ListAllSubnets()
				assert.NoError(t, err)
				assert.Equal(t, gotSubnet, cachedData)
				assert.Equal(t, gotDefaultSubnet, cachedDefaultData)
				assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllSubnets"))
				assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllDefaultSubnets"))
			}
			assertNoDiff(gotSubnet, tt.wantSubnet)
			assertNoDiff(gotDefaultSubnet, tt.wantDefaultSubnet)
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllNatGateways checks that ListAllNatGateways
// concatenates the gateways from both DescribeNatGatewaysPages pages in
// order, and stores the result in the cache under the "ec2ListAllNatGateways"
// key.
func Test_ec2Repository_ListAllNatGateways(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.NatGateway
		wantErr error
	}{
		{
			name: "List only gateways with multiple pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeNatGatewaysPages",
					&ec2.DescribeNatGatewaysInput{},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeNatGatewaysOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeNatGatewaysOutput{
							NatGateways: []*ec2.NatGateway{
								{
									NatGatewayId: aws.String("nat-0"),
								},
								{
									NatGatewayId: aws.String("nat-1"),
								},
							},
						}, false)
						callback(&ec2.DescribeNatGatewaysOutput{
							NatGateways: []*ec2.NatGateway{
								{
									NatGatewayId: aws.String("nat-2"),
								},
								{
									NatGatewayId: aws.String("nat-3"),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.NatGateway{
				{
					NatGatewayId: aws.String("nat-0"),
				},
				{
					NatGatewayId: aws.String("nat-1"),
				},
				{
					NatGatewayId: aws.String("nat-2"),
				},
				{
					NatGatewayId: aws.String("nat-3"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllNatGateways()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllNatGateways()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.NatGateway{}, store.Get("ec2ListAllNatGateways"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllRouteTables checks that ListAllRouteTables returns
// every route table from both DescribeRouteTablesPages pages unchanged and in
// order (including a main-associated table and one with an empty
// RouteTableId), and stores the result in the cache under the
// "ec2ListAllRouteTables" key.
func Test_ec2Repository_ListAllRouteTables(t *testing.T) {
	tests := []struct {
		name    string
		mocks   func(client *MockEC2Client)
		want    []*ec2.RouteTable
		wantErr error
	}{
		{
			name: "List only route with multiple pages",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeRouteTablesPages",
					&ec2.DescribeRouteTablesInput{},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeRouteTablesOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeRouteTablesOutput{
							RouteTables: []*ec2.RouteTable{
								{
									RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
										{
											DestinationCidrBlock: aws.String("1.1.1.1/32"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
										{
											DestinationIpv6CidrBlock: aws.String("::/0"),
											GatewayId:                aws.String("igw-030e74f73bd67f21b"),
										},
									},
								},
								{
									RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
										{
											DestinationCidrBlock: aws.String("0.0.0.0/0"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
										{
											DestinationIpv6CidrBlock: aws.String("::/0"),
											GatewayId:                aws.String("igw-030e74f73bd67f21b"),
										},
									},
								},
							},
						}, false)
						callback(&ec2.DescribeRouteTablesOutput{
							RouteTables: []*ec2.RouteTable{
								{
									RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table
									VpcId:        aws.String("vpc-09fe5abc2309ba49d"),
									Associations: []*ec2.RouteTableAssociation{
										{
											Main: aws.Bool(true),
										},
									},
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
										{
											DestinationCidrBlock: aws.String("10.1.1.0/24"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
										{
											DestinationCidrBlock: aws.String("10.1.2.0/24"),
											GatewayId:            aws.String("igw-030e74f73bd67f21b"),
										},
									},
								},
								{
									RouteTableId: aws.String(""), // table3
									Routes: []*ec2.Route{
										{
											DestinationCidrBlock: aws.String("10.0.0.0/16"),
											Origin:               aws.String("CreateRouteTable"), // default route
										},
									},
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			want: []*ec2.RouteTable{
				{
					RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
						{
							DestinationCidrBlock: aws.String("1.1.1.1/32"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
						{
							DestinationIpv6CidrBlock: aws.String("::/0"),
							GatewayId:                aws.String("igw-030e74f73bd67f21b"),
						},
					},
				},
				{
					RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
						{
							DestinationCidrBlock: aws.String("0.0.0.0/0"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
						{
							DestinationIpv6CidrBlock: aws.String("::/0"),
							GatewayId:                aws.String("igw-030e74f73bd67f21b"),
						},
					},
				},
				{
					RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table
					VpcId:        aws.String("vpc-09fe5abc2309ba49d"),
					Associations: []*ec2.RouteTableAssociation{
						{
							Main: aws.Bool(true),
						},
					},
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
						{
							DestinationCidrBlock: aws.String("10.1.1.0/24"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
						{
							DestinationCidrBlock: aws.String("10.1.2.0/24"),
							GatewayId:            aws.String("igw-030e74f73bd67f21b"),
						},
					},
				},
				{
					RouteTableId: aws.String(""), // table3
					Routes: []*ec2.Route{
						{
							DestinationCidrBlock: aws.String("10.0.0.0/16"),
							Origin:               aws.String("CreateRouteTable"), // default route
						},
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := cache.New(1)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			got, err := r.ListAllRouteTables()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so this result has to come from the cache.
				cachedData, err := r.ListAllRouteTables()
				assert.NoError(t, err)
				assert.Equal(t, got, cachedData)
				assert.IsType(t, []*ec2.RouteTable{}, store.Get("ec2ListAllRouteTables"))
			}
			// Report each field-level difference on its own; %v is used
			// because change values are not necessarily strings.
			changelog, err := diff.Diff(got, tt.want)
			assert.NoError(t, err)
			for _, change := range changelog {
				t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
			}
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllVPCs checks that ListAllVPCs splits the VPCs from
// both DescribeVpcsPages pages into non-default VPCs and default VPCs
// (IsDefault true), preserving order, and caches them under "ec2ListAllVPCs"
// and "ec2ListAllDefaultVPCs" respectively. The fixture includes a VPC whose
// IsDefault is unset, which is expected among the non-default VPCs.
func Test_ec2Repository_ListAllVPCs(t *testing.T) {
	tests := []struct {
		name           string
		mocks          func(client *MockEC2Client)
		wantVPC        []*ec2.Vpc
		wantDefaultVPC []*ec2.Vpc
		wantErr        error
	}{
		{
			name: "mixed default VPC and VPC",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeVpcsPages",
					&ec2.DescribeVpcsInput{},
					// The matcher feeds two pages to the repository's
					// pagination callback, then accepts the argument.
					mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeVpcsOutput{
							Vpcs: []*ec2.Vpc{
								{
									VpcId:     aws.String("vpc-a8c5d4c1"),
									IsDefault: aws.Bool(true),
								},
								{
									VpcId: aws.String("vpc-0768e1fd0029e3fc3"),
								},
								{
									VpcId:     aws.String("vpc-020b072316a95b97f"),
									IsDefault: aws.Bool(false),
								},
							},
						}, false)
						callback(&ec2.DescribeVpcsOutput{
							Vpcs: []*ec2.Vpc{
								{
									VpcId:     aws.String("vpc-02c50896b59598761"),
									IsDefault: aws.Bool(false),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			wantVPC: []*ec2.Vpc{
				{
					VpcId: aws.String("vpc-0768e1fd0029e3fc3"),
				},
				{
					VpcId:     aws.String("vpc-020b072316a95b97f"),
					IsDefault: aws.Bool(false),
				},
				{
					VpcId:     aws.String("vpc-02c50896b59598761"),
					IsDefault: aws.Bool(false),
				},
			},
			wantDefaultVPC: []*ec2.Vpc{
				{
					VpcId:     aws.String("vpc-a8c5d4c1"),
					IsDefault: aws.Bool(true),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Report each field-level difference between got and want as a
			// separate test error; %v is used because change values are not
			// necessarily strings.
			assertNoDiff := func(got, want interface{}) {
				t.Helper()
				changelog, err := diff.Diff(got, want)
				assert.NoError(t, err)
				for _, change := range changelog {
					t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
				}
			}

			// Two cache entries are written: regular and default VPCs.
			store := cache.New(2)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			gotVPCs, gotDefaultVPCs, err := r.ListAllVPCs()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so both results have to come from the cache.
				cachedData, cachedDefaultData, err := r.ListAllVPCs()
				assert.NoError(t, err)
				assert.Equal(t, gotVPCs, cachedData)
				assert.Equal(t, gotDefaultVPCs, cachedDefaultData)
				assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllVPCs"))
				assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllDefaultVPCs"))
			}
			assertNoDiff(gotVPCs, tt.wantVPC)
			assertNoDiff(gotDefaultVPCs, tt.wantDefaultVPC)
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}
// Test_ec2Repository_ListAllSecurityGroups checks that ListAllSecurityGroups
// separates the group named "default" from the other security groups returned
// by DescribeSecurityGroupsPages, and caches the two lists under
// "ec2ListAllSecurityGroups" and "ec2ListAllDefaultSecurityGroups".
func Test_ec2Repository_ListAllSecurityGroups(t *testing.T) {
	tests := []struct {
		name                     string
		mocks                    func(client *MockEC2Client)
		wantSecurityGroup        []*ec2.SecurityGroup
		wantDefaultSecurityGroup []*ec2.SecurityGroup
		wantErr                  error
	}{
		{
			name: "List with 1 page",
			mocks: func(client *MockEC2Client) {
				client.On("DescribeSecurityGroupsPages",
					&ec2.DescribeSecurityGroupsInput{},
					// The matcher feeds a single (last) page to the
					// repository's pagination callback, then accepts it.
					mock.MatchedBy(func(callback func(res *ec2.DescribeSecurityGroupsOutput, lastPage bool) bool) bool {
						callback(&ec2.DescribeSecurityGroupsOutput{
							SecurityGroups: []*ec2.SecurityGroup{
								{
									GroupId:   aws.String("sg-0254c038e32f25530"),
									GroupName: aws.String("foo"),
								},
								{
									GroupId:   aws.String("sg-9e0204ff"),
									GroupName: aws.String("default"),
								},
							},
						}, true)
						return true
					})).Return(nil).Once()
			},
			wantSecurityGroup: []*ec2.SecurityGroup{
				{
					GroupId:   aws.String("sg-0254c038e32f25530"),
					GroupName: aws.String("foo"),
				},
			},
			wantDefaultSecurityGroup: []*ec2.SecurityGroup{
				{
					GroupId:   aws.String("sg-9e0204ff"),
					GroupName: aws.String("default"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Report each field-level difference between got and want as a
			// separate test error; %v is used because change values are not
			// necessarily strings.
			assertNoDiff := func(got, want interface{}) {
				t.Helper()
				changelog, err := diff.Diff(got, want)
				assert.NoError(t, err)
				for _, change := range changelog {
					t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To)
				}
			}

			// Two cache entries are written: regular and default groups.
			store := cache.New(2)
			client := &MockEC2Client{}
			tt.mocks(client)
			r := &ec2Repository{
				client: client,
				cache:  store,
			}
			gotSecurityGroups, gotDefaultSecurityGroups, err := r.ListAllSecurityGroups()
			assert.Equal(t, tt.wantErr, err)
			if err == nil {
				// Call again: the paginated expectation is registered with
				// Once(), so both results have to come from the cache.
				cachedData, cachedDefaultData, err := r.ListAllSecurityGroups()
				assert.NoError(t, err)
				assert.Equal(t, gotSecurityGroups, cachedData)
				assert.Equal(t, gotDefaultSecurityGroups, cachedDefaultData)
				assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllSecurityGroups"))
				assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllDefaultSecurityGroups"))
			}
			assertNoDiff(gotSecurityGroups, tt.wantSecurityGroup)
			assertNoDiff(gotDefaultSecurityGroups, tt.wantDefaultSecurityGroup)
			// Fail if any registered SDK call was never made.
			client.AssertExpectations(t)
		})
	}
}