diff --git a/pkg/remote/aws/repository/ec2_repository.go b/pkg/remote/aws/repository/ec2_repository.go index 4744641b..4cc7749a 100644 --- a/pkg/remote/aws/repository/ec2_repository.go +++ b/pkg/remote/aws/repository/ec2_repository.go @@ -35,7 +35,7 @@ func NewEC2Repository(session *session.Session, c cache.Cache) *ec2Repository { } func (r *ec2Repository) ListAllImages() ([]*ec2.Image, error) { - if v := r.cache.Get("ec2AllImages"); v != nil { + if v := r.cache.Get("ec2ListAllImages"); v != nil { return v.([]*ec2.Image), nil } @@ -48,12 +48,12 @@ func (r *ec2Repository) ListAllImages() ([]*ec2.Image, error) { if err != nil { return nil, err } - r.cache.Put("ec2AllImages", images.Images) + r.cache.Put("ec2ListAllImages", images.Images) return images.Images, err } func (r *ec2Repository) ListAllSnapshots() ([]*ec2.Snapshot, error) { - if v := r.cache.Get("ec2AllSnapshots"); v != nil { + if v := r.cache.Get("ec2ListAllSnapshots"); v != nil { return v.([]*ec2.Snapshot), nil } @@ -70,12 +70,12 @@ func (r *ec2Repository) ListAllSnapshots() ([]*ec2.Snapshot, error) { if err != nil { return nil, err } - r.cache.Put("ec2AllSnapshots", snapshots) + r.cache.Put("ec2ListAllSnapshots", snapshots) return snapshots, err } func (r *ec2Repository) ListAllVolumes() ([]*ec2.Volume, error) { - if v := r.cache.Get("ec2AllVolumes"); v != nil { + if v := r.cache.Get("ec2ListAllVolumes"); v != nil { return v.([]*ec2.Volume), nil } @@ -88,12 +88,12 @@ func (r *ec2Repository) ListAllVolumes() ([]*ec2.Volume, error) { if err != nil { return nil, err } - r.cache.Put("ec2AllVolumes", volumes) + r.cache.Put("ec2ListAllVolumes", volumes) return volumes, nil } func (r *ec2Repository) ListAllAddresses() ([]*ec2.Address, error) { - if v := r.cache.Get("ec2AllAddresses"); v != nil { + if v := r.cache.Get("ec2ListAllAddresses"); v != nil { return v.([]*ec2.Address), nil } @@ -102,12 +102,12 @@ func (r *ec2Repository) ListAllAddresses() ([]*ec2.Address, error) { if err != nil { return nil, err } - r.cache.Put("ec2AllAddresses", response.Addresses) + r.cache.Put("ec2ListAllAddresses", response.Addresses) return response.Addresses, nil } func (r *ec2Repository) ListAllAddressesAssociation() ([]string, error) { - if v := r.cache.Get("ec2AddressesAssociation"); v != nil { + if v := r.cache.Get("ec2ListAllAddressesAssociation"); v != nil { return v.([]string), nil } @@ -121,12 +121,12 @@ func (r *ec2Repository) ListAllAddressesAssociation() ([]string, error) { results = append(results, aws.StringValue(address.AssociationId)) } } - r.cache.Put("ec2AddressesAssociation", results) + r.cache.Put("ec2ListAllAddressesAssociation", results) return results, nil } func (r *ec2Repository) ListAllInstances() ([]*ec2.Instance, error) { - if v := r.cache.Get("ec2AllInstances"); v != nil { + if v := r.cache.Get("ec2ListAllInstances"); v != nil { return v.([]*ec2.Instance), nil } @@ -141,12 +141,12 @@ func (r *ec2Repository) ListAllInstances() ([]*ec2.Instance, error) { if err != nil { return nil, err } - r.cache.Put("ec2AllInstances", instances) + r.cache.Put("ec2ListAllInstances", instances) return instances, nil } func (r *ec2Repository) ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) { - if v := r.cache.Get("ec2AllKeyPairs"); v != nil { + if v := r.cache.Get("ec2ListAllKeyPairs"); v != nil { return v.([]*ec2.KeyPairInfo), nil } @@ -155,6 +155,6 @@ func (r *ec2Repository) ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) { if err != nil { return nil, err } - r.cache.Put("ec2AllKeyPairs", pairs.KeyPairs) + r.cache.Put("ec2ListAllKeyPairs", pairs.KeyPairs) return pairs.KeyPairs, err } diff --git a/pkg/remote/aws/repository/ec2_repository_test.go b/pkg/remote/aws/repository/ec2_repository_test.go index 2f094a0a..741c1ee9 100644 --- a/pkg/remote/aws/repository/ec2_repository_test.go +++ b/pkg/remote/aws/repository/ec2_repository_test.go @@ -1,7 +1,6 @@ package repository import ( - "reflect" "strings" "testing" @@ -51,14 +50,24 @@ func Test_ec2Repository_ListAllImages(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllImages() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllImages() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Image{}, store.Get("ec2ListAllImages")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -67,11 +76,6 @@ func Test_ec2Repository_ListAllImages(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllImages() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } } @@ -125,14 +129,24 @@ func Test_ec2Repository_ListAllSnapshots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllSnapshots() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllSnapshots() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Snapshot{}, store.Get("ec2ListAllSnapshots")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -141,11 +155,6 @@ func Test_ec2Repository_ListAllSnapshots(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllSnapshots() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } } @@ -195,14 +204,24 @@ func Test_ec2Repository_ListAllVolumes(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllVolumes() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllVolumes() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Volume{}, store.Get("ec2ListAllVolumes")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -211,11 +230,6 @@ func Test_ec2Repository_ListAllVolumes(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllVolumes() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } } @@ -250,14 +264,24 @@ func Test_ec2Repository_ListAllAddresses(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllAddresses() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAddresses() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddresses")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -266,11 +290,6 @@ func Test_ec2Repository_ListAllAddresses(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllAddresses() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } } @@ -305,14 +324,24 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllAddressesAssociation() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAddressesAssociation() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []string{}, store.Get("ec2ListAllAddressesAssociation")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -321,11 +350,6 @@ func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllAddressesAssociation() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } } @@ -399,14 +423,23 @@ func Test_ec2Repository_ListAllInstances(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllInstances() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllInstances() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Instance{}, store.Get("ec2ListAllInstances")) + } changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -415,11 +448,6 @@ func Test_ec2Repository_ListAllInstances(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllInstances() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } } @@ -454,14 +482,24 @@ func Test_ec2Repository_ListAllKeyPairs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockEC2Client{} tt.mocks(client) r := &ec2Repository{ client: client, - cache: cache.New(10), + cache: store, } got, err := r.ListAllKeyPairs() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllKeyPairs() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.KeyPairInfo{}, store.Get("ec2ListAllKeyPairs")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -470,11 +508,6 @@ func Test_ec2Repository_ListAllKeyPairs(t *testing.T) { } t.Fail() } - - // Check that results were cached - cachedData, err := r.ListAllKeyPairs() - assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(got, cachedData)) }) } }