feat: implement cache in cloudfront repository

main
sundowndev 2021-06-04 17:36:38 +02:00
parent 34297d4b79
commit e14552efed
4 changed files with 28 additions and 7 deletions

View File

@ -6,6 +6,7 @@ import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
testresource "github.com/cloudskiff/driftctl/test/resource" testresource "github.com/cloudskiff/driftctl/test/resource"
"github.com/aws/aws-sdk-go/service/cloudfront" "github.com/aws/aws-sdk-go/service/cloudfront"
@ -76,7 +77,7 @@ func TestCloudfrontDistributionSupplier_Resources(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
supplierLibrary.AddSupplier(NewCloudfrontDistributionSupplier(provider, deserializer, repository.NewCloudfrontClient(provider.session))) supplierLibrary.AddSupplier(NewCloudfrontDistributionSupplier(provider, deserializer, repository.NewCloudfrontClient(provider.session, cache.New(0))))
} }
t.Run(c.test, func(tt *testing.T) { t.Run(c.test, func(tt *testing.T) {

View File

@ -44,7 +44,7 @@ func Init(version string, alerter *alerter.Alerter,
sqsRepository := repository.NewSQSClient(provider.session) sqsRepository := repository.NewSQSClient(provider.session)
snsRepository := repository.NewSNSClient(provider.session) snsRepository := repository.NewSNSClient(provider.session)
dynamoDBRepository := repository.NewDynamoDBRepository(provider.session) dynamoDBRepository := repository.NewDynamoDBRepository(provider.session)
cloudfrontRepository := repository.NewCloudfrontClient(provider.session) cloudfrontRepository := repository.NewCloudfrontClient(provider.session, repositoryCache)
kmsRepository := repository.NewKMSRepository(provider.session) kmsRepository := repository.NewKMSRepository(provider.session)
ecrRepository := repository.NewECRRepository(provider.session) ecrRepository := repository.NewECRRepository(provider.session)
iamRepository := repository.NewIAMRepository(provider.session, repositoryCache) iamRepository := repository.NewIAMRepository(provider.session, repositoryCache)

View File

@ -4,6 +4,7 @@ import (
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudfront" "github.com/aws/aws-sdk-go/service/cloudfront"
"github.com/aws/aws-sdk-go/service/cloudfront/cloudfrontiface" "github.com/aws/aws-sdk-go/service/cloudfront/cloudfrontiface"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
) )
type CloudfrontRepository interface { type CloudfrontRepository interface {
@ -12,15 +13,21 @@ type CloudfrontRepository interface {
type cloudfrontRepository struct { type cloudfrontRepository struct {
client cloudfrontiface.CloudFrontAPI client cloudfrontiface.CloudFrontAPI
cache cache.Cache
} }
func NewCloudfrontClient(session *session.Session) *cloudfrontRepository { func NewCloudfrontClient(session *session.Session, c cache.Cache) *cloudfrontRepository {
return &cloudfrontRepository{ return &cloudfrontRepository{
cloudfront.New(session), cloudfront.New(session),
c,
} }
} }
func (r *cloudfrontRepository) ListAllDistributions() ([]*cloudfront.DistributionSummary, error) { func (r *cloudfrontRepository) ListAllDistributions() ([]*cloudfront.DistributionSummary, error) {
if v := r.cache.Get("cloudfrontListAllDistributions"); v != nil {
return v.([]*cloudfront.DistributionSummary), nil
}
var distributions []*cloudfront.DistributionSummary var distributions []*cloudfront.DistributionSummary
input := cloudfront.ListDistributionsInput{} input := cloudfront.ListDistributionsInput{}
err := r.client.ListDistributionsPages(&input, err := r.client.ListDistributionsPages(&input,
@ -34,5 +41,7 @@ func (r *cloudfrontRepository) ListAllDistributions() ([]*cloudfront.Distributio
if err != nil { if err != nil {
return nil, err return nil, err
} }
r.cache.Put("cloudfrontListAllDistributions", distributions)
return distributions, nil return distributions, nil
} }

View File

@ -4,10 +4,10 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/aws/aws-sdk-go/service/cloudfront"
awstest "github.com/cloudskiff/driftctl/test/aws"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/cloudfront"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
awstest "github.com/cloudskiff/driftctl/test/aws"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
@ -47,7 +47,7 @@ func Test_cloudfrontRepository_ListAllDistributions(t *testing.T) {
}, },
}, true) }, true)
return true return true
})).Return(nil) })).Return(nil).Once()
}, },
want: []*cloudfront.DistributionSummary{ want: []*cloudfront.DistributionSummary{
{Id: aws.String("distribution1")}, {Id: aws.String("distribution1")},
@ -61,13 +61,24 @@ func Test_cloudfrontRepository_ListAllDistributions(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
store := cache.New(1)
client := awstest.MockFakeCloudFront{} client := awstest.MockFakeCloudFront{}
tt.mocks(&client) tt.mocks(&client)
r := &cloudfrontRepository{ r := &cloudfrontRepository{
client: &client, client: &client,
cache: store,
} }
got, err := r.ListAllDistributions() got, err := r.ListAllDistributions()
assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.wantErr, err)
if err == nil {
// Check that results were cached
cachedData, err := r.ListAllDistributions()
assert.NoError(t, err)
assert.Equal(t, got, cachedData)
assert.IsType(t, []*cloudfront.DistributionSummary{}, store.Get("cloudfrontListAllDistributions"))
}
changelog, err := diff.Diff(got, tt.want) changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err) assert.Nil(t, err)
if len(changelog) > 0 { if len(changelog) > 0 {