From 4f9fed2c3ad4a5bbb9e5ec2457fbd75df627f47d Mon Sep 17 00:00:00 2001 From: Martin Liu Date: Wed, 16 Nov 2022 17:30:50 -0800 Subject: [PATCH] fix: aws_lb_listener cache key --- .../remote/aws/repository/elbv2_repository.go | 7 +- .../aws/repository/elbv2_repository_test.go | 205 ++++++++++++------ 2 files changed, 142 insertions(+), 70 deletions(-) diff --git a/enumeration/remote/aws/repository/elbv2_repository.go b/enumeration/remote/aws/repository/elbv2_repository.go index 28639b15..d6a001e6 100644 --- a/enumeration/remote/aws/repository/elbv2_repository.go +++ b/enumeration/remote/aws/repository/elbv2_repository.go @@ -1,6 +1,8 @@ package repository import ( + "fmt" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/elbv2" "github.com/aws/aws-sdk-go/service/elbv2/elbv2iface" @@ -45,7 +47,8 @@ func (r *elbv2Repository) ListAllLoadBalancers() ([]*elbv2.LoadBalancer, error) } func (r *elbv2Repository) ListAllLoadBalancerListeners(loadBalancerArn string) ([]*elbv2.Listener, error) { - if v := r.cache.Get("elbv2ListAllLoadBalancerListeners"); v != nil { + cacheKey := fmt.Sprintf("elbv2ListAllLoadBalancerListeners_%s", loadBalancerArn) + if v := r.cache.Get(cacheKey); v != nil { return v.([]*elbv2.Listener), nil } @@ -60,6 +63,6 @@ func (r *elbv2Repository) ListAllLoadBalancerListeners(loadBalancerArn string) ( if err != nil { return nil, err } - r.cache.Put("elbv2ListAllLoadBalancerListeners", results) + r.cache.Put(cacheKey, results) return results, err } diff --git a/enumeration/remote/aws/repository/elbv2_repository_test.go b/enumeration/remote/aws/repository/elbv2_repository_test.go index 2ca320d6..bc615fe7 100644 --- a/enumeration/remote/aws/repository/elbv2_repository_test.go +++ b/enumeration/remote/aws/repository/elbv2_repository_test.go @@ -1,18 +1,17 @@ package repository import ( - "github.com/snyk/driftctl/enumeration/remote/cache" "strings" "testing" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elbv2" "github.com/pkg/errors" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/remote/cache" + awstest "github.com/snyk/driftctl/test/aws" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func Test_ELBV2Repository_ListAllLoadBalancers(t *testing.T) { @@ -133,17 +132,53 @@ func Test_ELBV2Repository_ListAllLoadBalancers(t *testing.T) { func Test_ELBV2Repository_ListAllLoadBalancerListeners(t *testing.T) { dummyError := errors.New("dummy error") + type call struct { + loadBalancerArn string + mocks func(*awstest.MockFakeELBV2, *cache.MockCache) + want []*elbv2.Listener + wantErr error + } + tests := []struct { - name string - mocks func(*awstest.MockFakeELBV2, *cache.MockCache) - want []*elbv2.Listener - wantErr error + name string + calls []call }{ { name: "list load balancer listeners", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - results := &elbv2.DescribeListenersOutput{ - Listeners: []*elbv2.Listener{ + calls: []call{ + { + loadBalancerArn: "test-lb", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + results := &elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener-1"), + }, + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener-2"), + }, + }, + } + + store.On("Get", "elbv2ListAllLoadBalancerListeners_test-lb").Return(nil).Once() + + client.On("DescribeListenersPages", + &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, + mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { + callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ + results.Listeners[0], + }}, false) + callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ + results.Listeners[1], + }}, true) + return true + })).Return(nil).Once() + + store.On("Put", "elbv2ListAllLoadBalancerListeners_test-lb", results.Listeners).Return(false).Once() + }, + want: []*elbv2.Listener{ { LoadBalancerArn: aws.String("test-lb"), ListenerArn: aws.String("test-lb-listener-1"), @@ -153,90 +188,124 @@ func Test_ELBV2Repository_ListAllLoadBalancerListeners(t *testing.T) { ListenerArn: aws.String("test-lb-listener-2"), }, }, - } - - store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(nil).Once() - - client.On("DescribeListenersPages", - &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, - mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { - callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ - results.Listeners[0], - }}, false) - callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ - results.Listeners[1], - }}, true) - return true - })).Return(nil).Once() - - store.On("Put", "elbv2ListAllLoadBalancerListeners", results.Listeners).Return(false).Once() - }, - want: []*elbv2.Listener{ - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener-1"), - }, - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener-2"), }, }, }, { name: "list load balancer listeners from cache", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - output := &elbv2.DescribeListenersOutput{ - Listeners: []*elbv2.Listener{ + calls: []call{ + { + loadBalancerArn: "test-lb", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + output := &elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener"), + }, + }, + } + + store.On("Get", "elbv2ListAllLoadBalancerListeners_test-lb").Return(output.Listeners).Once() + }, + want: []*elbv2.Listener{ { LoadBalancerArn: aws.String("test-lb"), ListenerArn: aws.String("test-lb-listener"), }, }, - } - - store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(output.Listeners).Once() + }, }, - want: []*elbv2.Listener{ + }, + { + name: "list load balancer listeners from multiple load balancers", + calls: []call{ { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener"), + loadBalancerArn: "test-lb-1", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + output := &elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb-1"), + ListenerArn: aws.String("test-lb-1-listener"), + }, + }, + } + + store.On("Get", "elbv2ListAllLoadBalancerListeners_test-lb-1").Return(output.Listeners).Once() + }, + want: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb-1"), + ListenerArn: aws.String("test-lb-1-listener"), + }, + }, + }, + { + loadBalancerArn: "test-lb-2", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + output := &elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb-2"), + ListenerArn: aws.String("test-lb-2-listener"), + }, + }, + } + + store.On("Get", "elbv2ListAllLoadBalancerListeners_test-lb-2").Return(output.Listeners).Once() + }, + want: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb-2"), + ListenerArn: aws.String("test-lb-2-listener"), + }, + }, }, }, }, { name: "error listing load balancer listeners", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(nil).Once() + calls: []call{ + { + loadBalancerArn: "test-lb", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + store.On("Get", "elbv2ListAllLoadBalancerListeners_test-lb").Return(nil).Once() - client.On("DescribeListenersPages", - &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, - mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { - callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{}}, true) - return true - })).Return(dummyError).Once() + client.On("DescribeListenersPages", + &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, + mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { + callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{}}, true) + return true + })).Return(dummyError).Once() + }, + wantErr: dummyError, + }, }, - wantErr: dummyError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := &cache.MockCache{} client := &awstest.MockFakeELBV2{} - tt.mocks(client, store) - r := &elbv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllLoadBalancerListeners("test-lb") - assert.Equal(t, tt.wantErr, err) - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + for _, call := range tt.calls { + call.mocks(client, store) + r := &elbv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllLoadBalancerListeners(call.loadBalancerArn) + assert.Equal(t, call.wantErr, err) + + changelog, err := diff.Diff(got, call.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() } - t.Fail() } }) }