diff --git a/pkg/remote/aws/db_instance_supplier_test.go b/pkg/remote/aws/db_instance_supplier_test.go index 746a5f8e..278d625c 100644 --- a/pkg/remote/aws/db_instance_supplier_test.go +++ b/pkg/remote/aws/db_instance_supplier_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository" + "github.com/cloudskiff/driftctl/pkg/remote/cache" testresource "github.com/cloudskiff/driftctl/test/resource" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" @@ -112,7 +113,7 @@ func TestDBInstanceSupplier_Resources(t *testing.T) { if err != nil { t.Fatal(err) } - supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session))) + supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session, cache.New(0)))) } t.Run(tt.test, func(t *testing.T) { diff --git a/pkg/remote/aws/db_subnet_group_supplier_test.go b/pkg/remote/aws/db_subnet_group_supplier_test.go index 927cdc04..bcfbdef7 100644 --- a/pkg/remote/aws/db_subnet_group_supplier_test.go +++ b/pkg/remote/aws/db_subnet_group_supplier_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository" + "github.com/cloudskiff/driftctl/pkg/remote/cache" testresource "github.com/cloudskiff/driftctl/test/resource" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" @@ -85,7 +86,7 @@ func TestDBSubnetGroupSupplier_Resources(t *testing.T) { if err != nil { t.Fatal(err) } - supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session))) + supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session, cache.New(0)))) } t.Run(tt.test, func(t *testing.T) { diff --git a/pkg/remote/aws/init.go b/pkg/remote/aws/init.go index 2c0d96ce..03254ff2 100644 --- a/pkg/remote/aws/init.go +++ b/pkg/remote/aws/init.go @@ -40,7 +40,7 @@ func Init(version string, alerter *alerter.Alerter, ec2repository := repository.NewEC2Repository(provider.session, repositoryCache) route53repository := repository.NewRoute53Repository(provider.session) lambdaRepository := repository.NewLambdaRepository(provider.session) - rdsRepository := repository.NewRDSRepository(provider.session) + rdsRepository := repository.NewRDSRepository(provider.session, repositoryCache) sqsRepository := repository.NewSQSClient(provider.session) snsRepository := repository.NewSNSClient(provider.session) dynamoDBRepository := repository.NewDynamoDBRepository(provider.session) diff --git a/pkg/remote/aws/repository/rds_repository.go b/pkg/remote/aws/repository/rds_repository.go index e7590e12..da5702d9 100644 --- a/pkg/remote/aws/repository/rds_repository.go +++ b/pkg/remote/aws/repository/rds_repository.go @@ -4,6 +4,7 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/cloudskiff/driftctl/pkg/remote/cache" ) type RDSClient interface { @@ -17,15 +18,21 @@ type RDSRepository interface { type rdsRepository struct { client rdsiface.RDSAPI + cache cache.Cache } -func NewRDSRepository(session *session.Session) *rdsRepository { +func NewRDSRepository(session *session.Session, c cache.Cache) *rdsRepository { return &rdsRepository{ rds.New(session), + c, } } func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) { + if v := r.cache.Get("rdsListAllDBInstances"); v != nil { + return v.([]*rds.DBInstance), nil + } + var result []*rds.DBInstance input := &rds.DescribeDBInstancesInput{} err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool { @@ -35,10 +42,16 @@ func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) { if err != nil { return nil, err } + + r.cache.Put("rdsListAllDBInstances", result) return result, nil } func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) { + if v := r.cache.Get("rdsListAllDbSubnetGroups"); v != nil { + return v.([]*rds.DBSubnetGroup), nil + } + var subnetGroups []*rds.DBSubnetGroup input := rds.DescribeDBSubnetGroupsInput{} err := r.client.DescribeDBSubnetGroupsPages(&input, @@ -47,5 +60,7 @@ func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) { return !lastPage }, ) + + r.cache.Put("rdsListAllDbSubnetGroups", subnetGroups) return subnetGroups, err } diff --git a/pkg/remote/aws/repository/rds_repository_test.go b/pkg/remote/aws/repository/rds_repository_test.go index 1778a240..8d64671a 100644 --- a/pkg/remote/aws/repository/rds_repository_test.go +++ b/pkg/remote/aws/repository/rds_repository_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/rds" + "github.com/cloudskiff/driftctl/pkg/remote/cache" "github.com/r3labs/diff/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -39,7 +40,7 @@ func Test_rdsRepository_ListAllDBInstances(t *testing.T) { }, }, true) return true - })).Return(nil) + })).Return(nil).Once() }, want: []*rds.DBInstance{ {DBInstanceIdentifier: aws.String("1")}, @@ -53,13 +54,24 @@ func Test_rdsRepository_ListAllDBInstances(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockRDSClient{} tt.mocks(client) r := &rdsRepository{ client: client, + cache: store, } got, err := r.ListAllDBInstances() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllDBInstances() + assert.Nil(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*rds.DBInstance{}, store.Get("rdsListAllDBInstances")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 { @@ -100,7 +112,7 @@ func Test_rdsRepository_ListAllDbSubnetGroups(t *testing.T) { }, }, true) return true - })).Return(nil) + })).Return(nil).Once() }, want: []*rds.DBSubnetGroup{ {DBSubnetGroupName: aws.String("1")}, @@ -114,13 +126,24 @@ func Test_rdsRepository_ListAllDbSubnetGroups(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) client := &MockRDSClient{} tt.mocks(client) r := &rdsRepository{ client: client, + cache: store, } got, err := r.ListAllDbSubnetGroups() assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllDbSubnetGroups() + assert.Nil(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*rds.DBSubnetGroup{}, store.Get("rdsListAllDbSubnetGroups")) + } + changelog, err := diff.Diff(got, tt.want) assert.Nil(t, err) if len(changelog) > 0 {