diff --git a/pkg/remote/aws/ecr_repository_supplier_test.go b/pkg/remote/aws/ecr_repository_supplier_test.go index fd6f35f3..c913a33d 100644 --- a/pkg/remote/aws/ecr_repository_supplier_test.go +++ b/pkg/remote/aws/ecr_repository_supplier_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/service/ecr" + "github.com/cloudskiff/driftctl/pkg/remote/cache" testresource "github.com/cloudskiff/driftctl/test/resource" "github.com/aws/aws-sdk-go/aws" @@ -75,7 +76,7 @@ func TestEcrRepositorySupplier_Resources(t *testing.T) { if err != nil { t.Fatal(err) } - supplierLibrary.AddSupplier(NewECRRepositorySupplier(provider, deserializer, repository.NewECRRepository(provider.session))) + supplierLibrary.AddSupplier(NewECRRepositorySupplier(provider, deserializer, repository.NewECRRepository(provider.session, cache.New(0)))) } t.Run(c.test, func(tt *testing.T) { diff --git a/pkg/remote/aws/init.go b/pkg/remote/aws/init.go index ee7463c7..1c531a99 100644 --- a/pkg/remote/aws/init.go +++ b/pkg/remote/aws/init.go @@ -45,8 +45,8 @@ func Init(version string, alerter *alerter.Alerter, snsRepository := repository.NewSNSClient(provider.session) dynamoDBRepository := repository.NewDynamoDBRepository(provider.session, repositoryCache) cloudfrontRepository := repository.NewCloudfrontClient(provider.session) - kmsRepository := repository.NewKMSRepository(provider.session) - ecrRepository := repository.NewECRRepository(provider.session) + ecrRepository := repository.NewECRRepository(provider.session, repositoryCache) + kmsRepository := repository.NewKMSRepository(provider.session, repositoryCache) iamRepository := repository.NewIAMRepository(provider.session, repositoryCache) deserializer := resource.NewDeserializer(factory) diff --git a/pkg/remote/aws/kms_alias_supplier_test.go b/pkg/remote/aws/kms_alias_supplier_test.go index 2380abf1..13c62749 100644 --- a/pkg/remote/aws/kms_alias_supplier_test.go +++ b/pkg/remote/aws/kms_alias_supplier_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/service/kms" + "github.com/cloudskiff/driftctl/pkg/remote/cache" testresource "github.com/cloudskiff/driftctl/test/resource" "github.com/aws/aws-sdk-go/aws" @@ -77,7 +78,7 @@ func TestKMSAliasSupplier_Resources(t *testing.T) { if err != nil { t.Fatal(err) } - supplierLibrary.AddSupplier(NewKMSAliasSupplier(provider, deserializer, repository.NewKMSRepository(provider.session))) + supplierLibrary.AddSupplier(NewKMSAliasSupplier(provider, deserializer, repository.NewKMSRepository(provider.session, cache.New(0)))) } t.Run(c.test, func(tt *testing.T) { diff --git a/pkg/remote/aws/kms_key_supplier_test.go b/pkg/remote/aws/kms_key_supplier_test.go index faed896b..74bdf490 100644 --- a/pkg/remote/aws/kms_key_supplier_test.go +++ b/pkg/remote/aws/kms_key_supplier_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/service/kms" + "github.com/cloudskiff/driftctl/pkg/remote/cache" testresource "github.com/cloudskiff/driftctl/test/resource" "github.com/aws/aws-sdk-go/aws" @@ -77,7 +78,7 @@ func TestKMSKeySupplier_Resources(t *testing.T) { if err != nil { t.Fatal(err) } - supplierLibrary.AddSupplier(NewKMSKeySupplier(provider, deserializer, repository.NewKMSRepository(provider.session))) + supplierLibrary.AddSupplier(NewKMSKeySupplier(provider, deserializer, repository.NewKMSRepository(provider.session, cache.New(0)))) } t.Run(c.test, func(tt *testing.T) { diff --git a/pkg/remote/aws/repository/ecr_repository.go b/pkg/remote/aws/repository/ecr_repository.go index ed4ff0c2..fea43ccb 100644 --- a/pkg/remote/aws/repository/ecr_repository.go +++ b/pkg/remote/aws/repository/ecr_repository.go @@ -4,6 +4,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecr" "github.com/aws/aws-sdk-go/service/ecr/ecriface" + "github.com/cloudskiff/driftctl/pkg/remote/cache" ) type ECRRepository interface { @@ -12,15 +13,21 @@ type ECRRepository interface { type ecrRepository struct { client ecriface.ECRAPI + cache cache.Cache } -func NewECRRepository(session *session.Session) *ecrRepository { +func NewECRRepository(session *session.Session, c cache.Cache) *ecrRepository { return &ecrRepository{ ecr.New(session), + c, } } func (r *ecrRepository) ListAllRepositories() ([]*ecr.Repository, error) { + if v := r.cache.Get("ecrListAllRepositories"); v != nil { + return v.([]*ecr.Repository), nil + } + var repositories []*ecr.Repository input := &ecr.DescribeRepositoriesInput{} err := r.client.DescribeRepositoriesPages(input, func(res *ecr.DescribeRepositoriesOutput, lastPage bool) bool { @@ -30,5 +37,7 @@ func (r *ecrRepository) ListAllRepositories() ([]*ecr.Repository, error) { if err != nil { return nil, err } + + r.cache.Put("ecrListAllRepositories", repositories) return repositories, nil } diff --git a/pkg/remote/aws/repository/ecr_repository_test.go b/pkg/remote/aws/repository/ecr_repository_test.go index c8189453..11323714 100644 --- a/pkg/remote/aws/repository/ecr_repository_test.go +++ b/pkg/remote/aws/repository/ecr_repository_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/aws/aws-sdk-go/service/ecr" + "github.com/cloudskiff/driftctl/pkg/remote/cache" awstest "github.com/cloudskiff/driftctl/test/aws" "github.com/aws/aws-sdk-go/aws" @@ -44,7 +45,7 @@ func Test_ecrRepository_ListAllRepositories(t *testing.T) { }, }, true) return true - })).Return(nil) + })).Return(nil).Once() }, want: []*ecr.Repository{ {RepositoryName: aws.String("1")}, @@ -58,13 +59,24 @@ func Test_ecrRepository_ListAllRepositories(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := awstest.MockFakeECR{} tt.mocks(&client) r := &ecrRepository{ client: &client, + cache: store, } got, err := r.ListAllRepositories() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllRepositories() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ecr.Repository{}, store.Get("ecrListAllRepositories")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { diff --git a/pkg/remote/aws/repository/kms_repository.go b/pkg/remote/aws/repository/kms_repository.go index 09f4dfa4..67eaa4f3 100644 --- a/pkg/remote/aws/repository/kms_repository.go +++ b/pkg/remote/aws/repository/kms_repository.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/kms/kmsiface" + "github.com/cloudskiff/driftctl/pkg/remote/cache" ) type KMSRepository interface { @@ -15,15 +16,21 @@ type KMSRepository interface { type kmsRepository struct { client kmsiface.KMSAPI + cache cache.Cache } -func NewKMSRepository(session *session.Session) *kmsRepository { +func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository { return &kmsRepository{ kms.New(session), + c, } } func (r *kmsRepository) ListAllKeys() ([]*kms.KeyListEntry, error) { + if v := r.cache.Get("kmsListAllKeys"); v != nil { + return v.([]*kms.KeyListEntry), nil + } + var keys []*kms.KeyListEntry input := kms.ListKeysInput{} err := r.client.ListKeysPages(&input, @@ -39,10 +46,16 @@ func (r *kmsRepository) ListAllKeys() ([]*kms.KeyListEntry, error) { if err != nil { return nil, err } + + r.cache.Put("kmsListAllKeys", customerKeys) return customerKeys, nil } func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) { + if v := r.cache.Get("kmsListAllAliases"); v != nil { + return v.([]*kms.AliasListEntry), nil + } + var aliases []*kms.AliasListEntry input := kms.ListAliasesInput{} err := r.client.ListAliasesPages(&input, @@ -54,7 +67,10 @@ func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) { if err != nil { return nil, err } - return r.filterAliases(aliases), nil + + result := r.filterAliases(aliases) + r.cache.Put("kmsListAllAliases", result) + return result, nil } func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) { diff --git a/pkg/remote/aws/repository/kms_repository_test.go b/pkg/remote/aws/repository/kms_repository_test.go index 80e57a8d..9850abfd 100644 --- a/pkg/remote/aws/repository/kms_repository_test.go +++ b/pkg/remote/aws/repository/kms_repository_test.go @@ -4,11 +4,10 @@ import ( "strings" "testing" - "github.com/aws/aws-sdk-go/service/kms" - awstest "github.com/cloudskiff/driftctl/test/aws" - "github.com/aws/aws-sdk-go/aws" - + "github.com/aws/aws-sdk-go/service/kms" + "github.com/cloudskiff/driftctl/pkg/remote/cache" + awstest "github.com/cloudskiff/driftctl/test/aws" "github.com/stretchr/testify/mock" "github.com/r3labs/diff/v2" @@ -36,7 +35,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { }, }, true) return true - })).Return(nil) + })).Return(nil).Once() client.On("DescribeKey", &kms.DescribeKeyInput{ KeyId: aws.String("1"), @@ -45,7 +44,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { KeyId: aws.String("1"), KeyManager: aws.String("CUSTOMER"), }, - }, nil) + }, nil).Once() client.On("DescribeKey", &kms.DescribeKeyInput{ KeyId: aws.String("2"), @@ -54,7 +53,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { KeyId: aws.String("2"), KeyManager: aws.String("AWS"), }, - }, nil) + }, nil).Once() client.On("DescribeKey", &kms.DescribeKeyInput{ KeyId: aws.String("3"), @@ -63,7 +62,7 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { KeyId: aws.String("3"), KeyManager: aws.String("AWS"), }, - }, nil) + }, nil).Once() }, want: []*kms.KeyListEntry{ {KeyId: aws.String("1")}, @@ -72,13 +71,24 @@ func Test_KMSRepository_ListAllKeys(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := awstest.MockFakeKMS{} tt.mocks(&client) r := &kmsRepository{ client: &client, + cache: store, } got, err := r.ListAllKeys() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllKeys() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*kms.KeyListEntry{}, store.Get("kmsListAllKeys")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -116,7 +126,7 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) { }, }, true) return true - })).Return(nil) + })).Return(nil).Once() }, want: []*kms.AliasListEntry{ {AliasName: aws.String("alias/1")}, @@ -129,13 +139,24 @@ func Test_KMSRepository_ListAllAliases(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := awstest.MockFakeKMS{} tt.mocks(&client) r := &kmsRepository{ client: &client, + cache: store, } got, err := r.ListAllAliases() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAliases() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*kms.AliasListEntry{}, store.Get("kmsListAllAliases")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 {