feat: implement cache in RDS repository
parent
34297d4b79
commit
991e777364
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
|
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
|
||||||
|
"github.com/cloudskiff/driftctl/pkg/remote/cache"
|
||||||
testresource "github.com/cloudskiff/driftctl/test/resource"
|
testresource "github.com/cloudskiff/driftctl/test/resource"
|
||||||
|
|
||||||
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
|
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
|
||||||
|
@ -112,7 +113,7 @@ func TestDBInstanceSupplier_Resources(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session)))
|
supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session, cache.New(0))))
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run(tt.test, func(t *testing.T) {
|
t.Run(tt.test, func(t *testing.T) {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
|
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
|
||||||
|
"github.com/cloudskiff/driftctl/pkg/remote/cache"
|
||||||
testresource "github.com/cloudskiff/driftctl/test/resource"
|
testresource "github.com/cloudskiff/driftctl/test/resource"
|
||||||
|
|
||||||
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
|
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
|
||||||
|
@ -85,7 +86,7 @@ func TestDBSubnetGroupSupplier_Resources(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session)))
|
supplierLibrary.AddSupplier(NewDBInstanceSupplier(provider, deserializer, repository.NewRDSRepository(provider.session, cache.New(0))))
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run(tt.test, func(t *testing.T) {
|
t.Run(tt.test, func(t *testing.T) {
|
||||||
|
|
|
@ -40,7 +40,7 @@ func Init(version string, alerter *alerter.Alerter,
|
||||||
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
|
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
|
||||||
route53repository := repository.NewRoute53Repository(provider.session)
|
route53repository := repository.NewRoute53Repository(provider.session)
|
||||||
lambdaRepository := repository.NewLambdaRepository(provider.session)
|
lambdaRepository := repository.NewLambdaRepository(provider.session)
|
||||||
rdsRepository := repository.NewRDSRepository(provider.session)
|
rdsRepository := repository.NewRDSRepository(provider.session, repositoryCache)
|
||||||
sqsRepository := repository.NewSQSClient(provider.session)
|
sqsRepository := repository.NewSQSClient(provider.session)
|
||||||
snsRepository := repository.NewSNSClient(provider.session)
|
snsRepository := repository.NewSNSClient(provider.session)
|
||||||
dynamoDBRepository := repository.NewDynamoDBRepository(provider.session)
|
dynamoDBRepository := repository.NewDynamoDBRepository(provider.session)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"github.com/aws/aws-sdk-go/aws/session"
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
"github.com/aws/aws-sdk-go/service/rds"
|
"github.com/aws/aws-sdk-go/service/rds"
|
||||||
"github.com/aws/aws-sdk-go/service/rds/rdsiface"
|
"github.com/aws/aws-sdk-go/service/rds/rdsiface"
|
||||||
|
"github.com/cloudskiff/driftctl/pkg/remote/cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RDSClient interface {
|
type RDSClient interface {
|
||||||
|
@ -17,15 +18,21 @@ type RDSRepository interface {
|
||||||
|
|
||||||
type rdsRepository struct {
|
type rdsRepository struct {
|
||||||
client rdsiface.RDSAPI
|
client rdsiface.RDSAPI
|
||||||
|
cache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRDSRepository(session *session.Session) *rdsRepository {
|
func NewRDSRepository(session *session.Session, c cache.Cache) *rdsRepository {
|
||||||
return &rdsRepository{
|
return &rdsRepository{
|
||||||
rds.New(session),
|
rds.New(session),
|
||||||
|
c,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) {
|
func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) {
|
||||||
|
if v := r.cache.Get("rdsListAllDBInstances"); v != nil {
|
||||||
|
return v.([]*rds.DBInstance), nil
|
||||||
|
}
|
||||||
|
|
||||||
var result []*rds.DBInstance
|
var result []*rds.DBInstance
|
||||||
input := &rds.DescribeDBInstancesInput{}
|
input := &rds.DescribeDBInstancesInput{}
|
||||||
err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool {
|
err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool {
|
||||||
|
@ -35,10 +42,16 @@ func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.cache.Put("rdsListAllDBInstances", result)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) {
|
func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) {
|
||||||
|
if v := r.cache.Get("rdsListAllDbSubnetGroups"); v != nil {
|
||||||
|
return v.([]*rds.DBSubnetGroup), nil
|
||||||
|
}
|
||||||
|
|
||||||
var subnetGroups []*rds.DBSubnetGroup
|
var subnetGroups []*rds.DBSubnetGroup
|
||||||
input := rds.DescribeDBSubnetGroupsInput{}
|
input := rds.DescribeDBSubnetGroupsInput{}
|
||||||
err := r.client.DescribeDBSubnetGroupsPages(&input,
|
err := r.client.DescribeDBSubnetGroupsPages(&input,
|
||||||
|
@ -47,5 +60,7 @@ func (r *rdsRepository) ListAllDbSubnetGroups() ([]*rds.DBSubnetGroup, error) {
|
||||||
return !lastPage
|
return !lastPage
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
r.cache.Put("rdsListAllDbSubnetGroups", subnetGroups)
|
||||||
return subnetGroups, err
|
return subnetGroups, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go/aws"
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
"github.com/aws/aws-sdk-go/service/rds"
|
"github.com/aws/aws-sdk-go/service/rds"
|
||||||
|
"github.com/cloudskiff/driftctl/pkg/remote/cache"
|
||||||
"github.com/r3labs/diff/v2"
|
"github.com/r3labs/diff/v2"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
@ -39,7 +40,7 @@ func Test_rdsRepository_ListAllDBInstances(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, true)
|
}, true)
|
||||||
return true
|
return true
|
||||||
})).Return(nil)
|
})).Return(nil).Once()
|
||||||
},
|
},
|
||||||
want: []*rds.DBInstance{
|
want: []*rds.DBInstance{
|
||||||
{DBInstanceIdentifier: aws.String("1")},
|
{DBInstanceIdentifier: aws.String("1")},
|
||||||
|
@ -53,13 +54,24 @@ func Test_rdsRepository_ListAllDBInstances(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
store := cache.New(1)
|
||||||
client := &MockRDSClient{}
|
client := &MockRDSClient{}
|
||||||
tt.mocks(client)
|
tt.mocks(client)
|
||||||
r := &rdsRepository{
|
r := &rdsRepository{
|
||||||
client: client,
|
client: client,
|
||||||
|
cache: store,
|
||||||
}
|
}
|
||||||
got, err := r.ListAllDBInstances()
|
got, err := r.ListAllDBInstances()
|
||||||
assert.Equal(t, tt.wantErr, err)
|
assert.Equal(t, tt.wantErr, err)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// Check that results were cached
|
||||||
|
cachedData, err := r.ListAllDBInstances()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, got, cachedData)
|
||||||
|
assert.IsType(t, []*rds.DBInstance{}, store.Get("rdsListAllDBInstances"))
|
||||||
|
}
|
||||||
|
|
||||||
changelog, err := diff.Diff(got, tt.want)
|
changelog, err := diff.Diff(got, tt.want)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
if len(changelog) > 0 {
|
if len(changelog) > 0 {
|
||||||
|
@ -100,7 +112,7 @@ func Test_rdsRepository_ListAllDbSubnetGroups(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, true)
|
}, true)
|
||||||
return true
|
return true
|
||||||
})).Return(nil)
|
})).Return(nil).Once()
|
||||||
},
|
},
|
||||||
want: []*rds.DBSubnetGroup{
|
want: []*rds.DBSubnetGroup{
|
||||||
{DBSubnetGroupName: aws.String("1")},
|
{DBSubnetGroupName: aws.String("1")},
|
||||||
|
@ -114,13 +126,24 @@ func Test_rdsRepository_ListAllDbSubnetGroups(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
store := cache.New(1)
|
||||||
client := &MockRDSClient{}
|
client := &MockRDSClient{}
|
||||||
tt.mocks(client)
|
tt.mocks(client)
|
||||||
r := &rdsRepository{
|
r := &rdsRepository{
|
||||||
client: client,
|
client: client,
|
||||||
|
cache: store,
|
||||||
}
|
}
|
||||||
got, err := r.ListAllDbSubnetGroups()
|
got, err := r.ListAllDbSubnetGroups()
|
||||||
assert.Equal(t, tt.wantErr, err)
|
assert.Equal(t, tt.wantErr, err)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
// Check that results were cached
|
||||||
|
cachedData, err := r.ListAllDbSubnetGroups()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, got, cachedData)
|
||||||
|
assert.IsType(t, []*rds.DBSubnetGroup{}, store.Get("rdsListAllDbSubnetGroups"))
|
||||||
|
}
|
||||||
|
|
||||||
changelog, err := diff.Diff(got, tt.want)
|
changelog, err := diff.Diff(got, tt.want)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
if len(changelog) > 0 {
|
if len(changelog) > 0 {
|
||||||
|
|
Loading…
Reference in New Issue