feat: implement cache in RDS repository

main
sundowndev 2021-06-04 16:01:06 +02:00
parent 34297d4b79
commit 991e777364
5 changed files with 46 additions and 6 deletions

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
testresource "github.com/cloudskiff/driftctl/test/resource" testresource "github.com/cloudskiff/driftctl/test/resource"
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
@ -112,7 +113,7 @@ func TestDBInstanceSupplier_Resources(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session))) supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session, cache.New(0))))
} }
t.Run(tt.test, func(t *testing.T) { t.Run(tt.test, func(t *testing.T) {

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
testresource "github.com/cloudskiff/driftctl/test/resource" testresource "github.com/cloudskiff/driftctl/test/resource"
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
@ -85,7 +86,7 @@ func TestDBSubnetGroupSupplier_Resources(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session))) supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session, cache.New(0))))
} }
t.Run(tt.test, func(t *testing.T) { t.Run(tt.test, func(t *testing.T) {

View File

@ -40,7 +40,7 @@ func Init(version string, alerter *alerter.Alerter,
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache) ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
route53repository := repository.NewRoute53Repository(provider.session) route53repository := repository.NewRoute53Repository(provider.session)
lambdaRepository := repository.NewLambdaRepository(provider.session) lambdaRepository := repository.NewLambdaRepository(provider.session)
rdsRepository := repository.NewRDSRepository(provider.session) rdsRepository := repository.NewRDSRepository(provider.session, repositoryCache)
sqsRepository := repository.NewSQSClient(provider.session) sqsRepository := repository.NewSQSClient(provider.session)
snsRepository := repository.NewSNSClient(provider.session) snsRepository := repository.NewSNSClient(provider.session)
dynamoDBRepository := repository.NewDynamoDBRepository(provider.session) dynamoDBRepository := repository.NewDynamoDBRepository(provider.session)

View File

@ -4,6 +4,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds"
"github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/rds/rdsiface"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
) )
type RDSClient interface { type RDSClient interface {
@ -17,15 +18,21 @@ type RDSRepository interface {
type rdsRepository struct { type rdsRepository struct {
client rdsiface.RDSAPI client rdsiface.RDSAPI
cache cache.Cache
} }
func NewRDSRepository(session *session.Session) *rdsRepository { func NewRDSRepository(session *session.Session, c cache.Cache) *rdsRepository {
return &rdsRepository{ return &rdsRepository{
rds.New(session), rds.New(session),
c,
} }
} }
func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) { func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) {
if v := r.cache.Get("rdsListAllDBInstances"); v != nil {
return v.([]*rds.DBInstance), nil
}
var result []*rds.DBInstance var result []*rds.DBInstance
input := &rds.DescribeDBInstancesInput{} input := &rds.DescribeDBInstancesInput{}
err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool { err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool {
@ -35,10 +42,16 @@ func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
r.cache.Put("rdsListAllDBInstances", result)
return result, nil return result, nil
} }
func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) { func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) {
if v := r.cache.Get("rdsListAllDbSubnetGroups"); v != nil {
return v.([]*rds.DBSubnetGroup), nil
}
var subnetGroups []*rds.DBSubnetGroup var subnetGroups []*rds.DBSubnetGroup
input := rds.DescribeDBSubnetGroupsInput{} input := rds.DescribeDBSubnetGroupsInput{}
err := r.client.DescribeDBSubnetGroupsPages(&input, err := r.client.DescribeDBSubnetGroupsPages(&input,
@ -47,5 +60,7 @@ func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) {
return !lastPage return !lastPage
}, },
) )
r.cache.Put("rdsListAllDbSubnetGroups", subnetGroups)
return subnetGroups, err return subnetGroups, err
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/rds"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
"github.com/r3labs/diff/v2" "github.com/r3labs/diff/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -39,7 +40,7 @@ func Test_rdsRepository_ListAllDBInstances(t *testing.T) {
}, },
}, true) }, true)
return true return true
})).Return(nil) })).Return(nil).Once()
}, },
want: []*rds.DBInstance{ want: []*rds.DBInstance{
{DBInstanceIdentifier: aws.String("1")}, {DBInstanceIdentifier: aws.String("1")},
@ -53,13 +54,24 @@ func Test_rdsRepository_ListAllDBInstances(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
store := cache.New(1)
client := &MockRDSClient{} client := &MockRDSClient{}
tt.mocks(client) tt.mocks(client)
r := &rdsRepository{ r := &rdsRepository{
client: client, client: client,
cache: store,
} }
got, err := r.ListAllDBInstances() got, err := r.ListAllDBInstances()
assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.wantErr, err)
if err == nil {
// Check that results were cached
cachedData, err := r.ListAllDBInstances()
assert.Nil(t, err)
assert.Equal(t, got, cachedData)
assert.IsType(t, []*rds.DBInstance{}, store.Get("rdsListAllDBInstances"))
}
changelog, err := diff.Diff(got, tt.want) changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err) assert.Nil(t, err)
if len(changelog) > 0 { if len(changelog) > 0 {
@ -100,7 +112,7 @@ func Test_rdsRepository_ListAllDbSubnetGroups(t *testing.T) {
}, },
}, true) }, true)
return true return true
})).Return(nil) })).Return(nil).Once()
}, },
want: []*rds.DBSubnetGroup{ want: []*rds.DBSubnetGroup{
{DBSubnetGroupName: aws.String("1")}, {DBSubnetGroupName: aws.String("1")},
@ -114,13 +126,24 @@ func Test_rdsRepository_ListAllDbSubnetGroups(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
store := cache.New(1)
client := &MockRDSClient{} client := &MockRDSClient{}
tt.mocks(client) tt.mocks(client)
r := &rdsRepository{ r := &rdsRepository{
client: client, client: client,
cache: store,
} }
got, err := r.ListAllDbSubnetGroups() got, err := r.ListAllDbSubnetGroups()
assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.wantErr, err)
if err == nil {
// Check that results were cached
cachedData, err := r.ListAllDbSubnetGroups()
assert.Nil(t, err)
assert.Equal(t, got, cachedData)
assert.IsType(t, []*rds.DBSubnetGroup{}, store.Get("rdsListAllDbSubnetGroups"))
}
changelog, err := diff.Diff(got, tt.want) changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err) assert.Nil(t, err)
if len(changelog) > 0 { if len(changelog) > 0 {