From 34de289bd425ce23bb4ac28a25e46b2127fbc04a Mon Sep 17 00:00:00 2001 From: Elie Date: Fri, 16 Jul 2021 17:34:07 +0200 Subject: [PATCH] Add policy to aws_sqs_queue_policy enumerator --- .../aws/repository/mock_SQSRepository.go | 32 +++++++++- pkg/remote/aws/repository/sqs_repository.go | 23 +++++++ .../aws/repository/sqs_repository_test.go | 64 +++++++++++++++++++ pkg/remote/aws/sqs_queue_policy_enumerator.go | 9 ++- pkg/remote/sqs_scanner_test.go | 10 +++ 5 files changed, 134 insertions(+), 4 deletions(-) diff --git a/pkg/remote/aws/repository/mock_SQSRepository.go b/pkg/remote/aws/repository/mock_SQSRepository.go index 9028b57c..5e4c6bd2 100644 --- a/pkg/remote/aws/repository/mock_SQSRepository.go +++ b/pkg/remote/aws/repository/mock_SQSRepository.go @@ -1,14 +1,40 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. package repository -import "github.com/stretchr/testify/mock" +import ( + sqs "github.com/aws/aws-sdk-go/service/sqs" + mock "github.com/stretchr/testify/mock" +) -// MockSQSRepository is an autogenerated mock type for the MockSQSRepository type +// MockSQSRepository is an autogenerated mock type for the SQSRepository type type MockSQSRepository struct { mock.Mock } +// GetQueueAttributes provides a mock function with given fields: url +func (_m *MockSQSRepository) GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) { + ret := _m.Called(url) + + var r0 *sqs.GetQueueAttributesOutput + if rf, ok := ret.Get(0).(func(string) *sqs.GetQueueAttributesOutput); ok { + r0 = rf(url) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sqs.GetQueueAttributesOutput) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(url) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListAllQueues provides a mock function with given fields: func (_m *MockSQSRepository) ListAllQueues() ([]*string, error) { ret := _m.Called() diff --git a/pkg/remote/aws/repository/sqs_repository.go b/pkg/remote/aws/repository/sqs_repository.go index 77dc387f..d4debc98 100644 --- a/pkg/remote/aws/repository/sqs_repository.go +++ b/pkg/remote/aws/repository/sqs_repository.go @@ -1,6 +1,9 @@ package repository import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" @@ -9,6 +12,7 @@ import ( type SQSRepository interface { ListAllQueues() ([]*string, error) + GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) } type sqsRepository struct { @@ -23,6 +27,25 @@ func NewSQSRepository(session *session.Session, c cache.Cache) *sqsRepository { } } +func (r *sqsRepository) GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) { + cacheKey := fmt.Sprintf("sqsGetQueueAttributes_%s", url) + if v := r.cache.Get(cacheKey); v != nil { + return v.(*sqs.GetQueueAttributesOutput), nil + } + + attributes, err := r.client.GetQueueAttributes(&sqs.GetQueueAttributesInput{ + AttributeNames: aws.StringSlice([]string{sqs.QueueAttributeNamePolicy}), + QueueUrl: &url, + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, attributes) + + return attributes, nil +} + func (r *sqsRepository) ListAllQueues() ([]*string, error) { if v := r.cache.Get("sqsListAllQueues"); v != nil { return v.([]*string), nil diff --git a/pkg/remote/aws/repository/sqs_repository_test.go b/pkg/remote/aws/repository/sqs_repository_test.go index f623af14..237c1c26 100644 --- a/pkg/remote/aws/repository/sqs_repository_test.go +++ b/pkg/remote/aws/repository/sqs_repository_test.go @@ -79,3 +79,67 @@ func Test_sqsRepository_ListAllQueues(t *testing.T) { }) } } + +func Test_sqsRepository_GetQueueAttributes(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeSQS) + want *sqs.GetQueueAttributesOutput + wantErr error + }{ + { + name: "get attributes", + mocks: func(client *awstest.MockFakeSQS) { + client.On( + "GetQueueAttributes", + &sqs.GetQueueAttributesInput{ + AttributeNames: awssdk.StringSlice([]string{sqs.QueueAttributeNamePolicy}), + QueueUrl: awssdk.String("http://example.com"), + }, + ).Return( + &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + sqs.QueueAttributeNamePolicy: awssdk.String("foobar"), + }, + }, + nil, + ).Once() + }, + want: &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + sqs.QueueAttributeNamePolicy: awssdk.String("foobar"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeSQS{} + tt.mocks(client) + r := &sqsRepository{ + client: client, + cache: store, + } + got, err := r.GetQueueAttributes("http://example.com") + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.GetQueueAttributes("http://example.com") + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, &sqs.GetQueueAttributesOutput{}, store.Get("sqsGetQueueAttributes_http://example.com")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/pkg/remote/aws/sqs_queue_policy_enumerator.go b/pkg/remote/aws/sqs_queue_policy_enumerator.go index 90d8b581..b14e5dd6 100644 --- a/pkg/remote/aws/sqs_queue_policy_enumerator.go +++ b/pkg/remote/aws/sqs_queue_policy_enumerator.go @@ -1,6 +1,7 @@ package aws import ( + "github.com/aws/aws-sdk-go/service/sqs" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository" remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error" "github.com/cloudskiff/driftctl/pkg/resource" @@ -34,12 +35,18 @@ func (e *SQSQueuePolicyEnumerator) Enumerate() ([]resource.Resource, error) { results := make([]resource.Resource, len(queues)) for _, queue := range queues { + attributes, err := e.repository.GetQueueAttributes(*queue) + if err != nil { + return nil, remoteerror.NewResourceEnumerationError(err, string(e.SupportedType())) + } results = append( results, e.factory.CreateAbstractResource( string(e.SupportedType()), awssdk.StringValue(queue), - map[string]interface{}{}, + map[string]interface{}{ + "policy": *attributes.Attributes[sqs.QueueAttributeNamePolicy], + }, ), ) } diff --git a/pkg/remote/sqs_scanner_test.go b/pkg/remote/sqs_scanner_test.go index 82331ba3..38e35e7a 100644 --- a/pkg/remote/sqs_scanner_test.go +++ b/pkg/remote/sqs_scanner_test.go @@ -6,6 +6,7 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sqs" "github.com/cloudskiff/driftctl/mocks" "github.com/cloudskiff/driftctl/pkg/remote/aws" "github.com/cloudskiff/driftctl/pkg/remote/aws/repository" @@ -139,6 +140,15 @@ func TestSQSQueuePolicy(t *testing.T) { awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), }, nil) + + client.On("GetQueueAttributes", mock.Anything).Return( + &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + sqs.QueueAttributeNamePolicy: awssdk.String(""), + }, + }, + nil, + ) }, wantErr: nil, },