Update SQS suppliers and tests

main
William Beuil 2021-02-03 21:08:57 +01:00
parent d3c542c004
commit 0d738f9dc4
No known key found for this signature in database
GPG Key ID: BED2072C5C2BF537
9 changed files with 224 additions and 94 deletions

33
mocks/SQSRepository.go Normal file
View File

@ -0,0 +1,33 @@
// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks
import mock "github.com/stretchr/testify/mock"
// SQSRepository is an autogenerated mock type for the SQSRepository type
type SQSRepository struct {
mock.Mock
}
// ListAllQueues provides a mock function with given fields:
func (_m *SQSRepository) ListAllQueues() ([]*string, error) {
ret := _m.Called()
var r0 []*string
if rf, ok := ret.Get(0).(func() []*string); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*string)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -26,6 +26,10 @@ func (m AwsSqsQueuePolicyExpander) Execute(_, resourcesFromState *[]resource.Res
queue, _ := res.(*aws.AwsSqsQueue)
newList = append(newList, res)
if queue.Policy == nil {
continue
}
if m.hasPolicyAttached(queue, resourcesFromState) {
queue.Policy = nil
continue
@ -41,10 +45,6 @@ func (m AwsSqsQueuePolicyExpander) Execute(_, resourcesFromState *[]resource.Res
}
func (m *AwsSqsQueuePolicyExpander) handlePolicy(queue *aws.AwsSqsQueue, results *[]resource.Resource) error {
if queue.Policy == nil || *queue.Policy == "" {
return nil
}
newPolicy := &aws.AwsSqsQueuePolicy{
Id: queue.Id,
QueueUrl: awssdk.String(queue.Id),

View File

@ -0,0 +1,36 @@
package repository
import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
)
type SQSRepository interface {
ListAllQueues() ([]*string, error)
}
type sqsRepository struct {
client sqsiface.SQSAPI
}
func NewSQSClient(session *session.Session) *sqsRepository {
return &sqsRepository{
sqs.New(session),
}
}
func (r *sqsRepository) ListAllQueues() ([]*string, error) {
var queues []*string
input := sqs.ListQueuesInput{}
err := r.client.ListQueuesPages(&input,
func(resp *sqs.ListQueuesOutput, lastPage bool) bool {
queues = append(queues, resp.QueueUrls...)
return !lastPage
},
)
if err != nil {
return nil, err
}
return queues, nil
}

View File

@ -0,0 +1,69 @@
package repository
import (
"strings"
"testing"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/cloudskiff/driftctl/mocks"
"github.com/r3labs/diff/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func Test_sqsRepository_ListAllQueues(t *testing.T) {
tests := []struct {
name string
mocks func(client *mocks.FakeSQS)
want []*string
wantErr error
}{
{
name: "list with multiple pages",
mocks: func(client *mocks.FakeSQS) {
client.On("ListQueuesPages",
&sqs.ListQueuesInput{},
mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool {
callback(&sqs.ListQueuesOutput{
QueueUrls: []*string{
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"),
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"),
},
}, false)
callback(&sqs.ListQueuesOutput{
QueueUrls: []*string{
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"),
},
}, true)
return true
})).Return(nil)
},
want: []*string{
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"),
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"),
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := &mocks.FakeSQS{}
tt.mocks(client)
r := &sqsRepository{
client: client,
}
got, err := r.ListAllQueues()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
})
}
}

View File

@ -1,8 +1,7 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
"github.com/cloudskiff/driftctl/pkg/remote/deserializer"
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
"github.com/cloudskiff/driftctl/pkg/resource"
@ -16,7 +15,7 @@ import (
type SqsQueuePolicySupplier struct {
reader terraform.ResourceReader
deserializer deserializer.CTYDeserializer
client sqsiface.SQSAPI
client repository.SQSRepository
runner *terraform.ParallelResourceReader
}
@ -24,13 +23,13 @@ func NewSqsQueuePolicySupplier(provider *TerraformProvider) *SqsQueuePolicySuppl
return &SqsQueuePolicySupplier{
provider,
awsdeserializer.NewSqsQueuePolicyDeserializer(),
sqs.New(provider.session),
repository.NewSQSClient(provider.session),
terraform.NewParallelResourceReader(provider.Runner().SubRunner()),
}
}
func (s SqsQueuePolicySupplier) Resources() ([]resource.Resource, error) {
queues, err := listSqsQueues(s.client)
queues, err := s.client.ListAllQueues()
if err != nil {
return nil, remoteerror.NewResourceEnumerationErrorWithType(err, aws.AwsSqsQueuePolicyResourceType, aws.AwsSqsQueueResourceType)
}
@ -50,11 +49,11 @@ func (s SqsQueuePolicySupplier) Resources() ([]resource.Resource, error) {
return s.deserializer.Deserialize(resources)
}
func (s SqsQueuePolicySupplier) readSqsQueuePolicy(queue string) (cty.Value, error) {
func (s SqsQueuePolicySupplier) readSqsQueuePolicy(queueURL string) (cty.Value, error) {
var Ty resource.ResourceType = aws.AwsSqsQueuePolicyResourceType
val, err := s.reader.ReadResource(terraform.ReadResourceArgs{
Ty: Ty,
ID: queue,
ID: queueURL,
})
if err != nil {
logrus.WithFields(logrus.Fields{

View File

@ -8,8 +8,6 @@ import (
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
resourceaws "github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/aws/aws-sdk-go/service/sqs"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/cloudskiff/driftctl/mocks"
@ -29,7 +27,7 @@ func TestSqsQueuePolicySupplier_Resources(t *testing.T) {
cases := []struct {
test string
dirName string
mocks func(client *mocks.FakeSQS)
mocks func(client *mocks.SQSRepository)
err error
}{
{
@ -37,44 +35,28 @@ func TestSqsQueuePolicySupplier_Resources(t *testing.T) {
// as a default SQSDefaultPolicy (e.g. policy="") will always be present in each queue
test: "no sqs queue policies",
dirName: "sqs_queue_policy_empty",
mocks: func(client *mocks.FakeSQS) {
client.On("ListQueuesPages",
&sqs.ListQueuesInput{},
mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool {
callback(&sqs.ListQueuesOutput{}, true)
return true
})).Return(nil)
mocks: func(client *mocks.SQSRepository) {
client.On("ListAllQueues").Return([]*string{}, nil)
},
err: nil,
},
{
test: "multiple sqs queue policies (default or not)",
dirName: "sqs_queue_policy_multiple",
mocks: func(client *mocks.FakeSQS) {
client.On("ListQueuesPages",
&sqs.ListQueuesInput{},
mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool {
callback(&sqs.ListQueuesOutput{
QueueUrls: []*string{
mocks: func(client *mocks.SQSRepository) {
client.On("ListAllQueues").Return([]*string{
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"),
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"),
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"),
},
}, true)
return true
})).Return(nil)
}, nil)
},
err: nil,
},
{
test: "cannot list sqs queues, thus sqs queue policies",
dirName: "sqs_queue_policy_empty",
mocks: func(client *mocks.FakeSQS) {
client.On(
"ListQueuesPages",
&sqs.ListQueuesInput{},
mock.Anything,
).Once().Return(awserr.NewRequestFailure(nil, 403, ""))
mocks: func(client *mocks.SQSRepository) {
client.On("ListAllQueues").Return(nil, awserr.NewRequestFailure(nil, 403, ""))
},
err: remoteerror.NewResourceEnumerationErrorWithType(awserr.NewRequestFailure(nil, 403, ""), resourceaws.AwsSqsQueuePolicyResourceType, resourceaws.AwsSqsQueueResourceType),
},
@ -95,7 +77,7 @@ func TestSqsQueuePolicySupplier_Resources(t *testing.T) {
}
t.Run(c.test, func(tt *testing.T) {
fakeSQS := mocks.FakeSQS{}
fakeSQS := mocks.SQSRepository{}
c.mocks(&fakeSQS)
provider := mocks2.NewMockedGoldenTFProvider(c.dirName, providerLibrary.Provider(terraform.AWS), shouldUpdate)
sqsQueuePolicyDeserializer := awsdeserializer.NewSqsQueuePolicyDeserializer()

View File

@ -1,8 +1,7 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
"github.com/cloudskiff/driftctl/pkg/remote/aws/repository"
"github.com/cloudskiff/driftctl/pkg/remote/deserializer"
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
"github.com/cloudskiff/driftctl/pkg/resource"
@ -16,7 +15,7 @@ import (
type SqsQueueSupplier struct {
reader terraform.ResourceReader
deserializer deserializer.CTYDeserializer
client sqsiface.SQSAPI
client repository.SQSRepository
runner *terraform.ParallelResourceReader
}
@ -24,13 +23,13 @@ func NewSqsQueueSupplier(provider *TerraformProvider) *SqsQueueSupplier {
return &SqsQueueSupplier{
provider,
awsdeserializer.NewSqsQueueDeserializer(),
sqs.New(provider.session),
repository.NewSQSClient(provider.session),
terraform.NewParallelResourceReader(provider.Runner().SubRunner()),
}
}
func (s SqsQueueSupplier) Resources() ([]resource.Resource, error) {
queues, err := listSqsQueues(s.client)
queues, err := s.client.ListAllQueues()
if err != nil {
return nil, remoteerror.NewResourceEnumerationError(err, aws.AwsSqsQueueResourceType)
}
@ -50,11 +49,11 @@ func (s SqsQueueSupplier) Resources() ([]resource.Resource, error) {
return s.deserializer.Deserialize(resources)
}
func (s SqsQueueSupplier) readSqsQueue(queue string) (cty.Value, error) {
func (s SqsQueueSupplier) readSqsQueue(queueURL string) (cty.Value, error) {
var Ty resource.ResourceType = aws.AwsSqsQueueResourceType
val, err := s.reader.ReadResource(terraform.ReadResourceArgs{
Ty: Ty,
ID: queue,
ID: queueURL,
})
if err != nil {
logrus.WithFields(logrus.Fields{
@ -64,18 +63,3 @@ func (s SqsQueueSupplier) readSqsQueue(queue string) (cty.Value, error) {
}
return *val, nil
}
func listSqsQueues(client sqsiface.SQSAPI) ([]*string, error) {
var queues []*string
input := sqs.ListQueuesInput{}
err := client.ListQueuesPages(&input,
func(resp *sqs.ListQueuesOutput, lastPage bool) bool {
queues = append(queues, resp.QueueUrls...)
return !lastPage
},
)
if err != nil {
return nil, err
}
return queues, nil
}

View File

@ -4,14 +4,12 @@ import (
"context"
"testing"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
remoteerror "github.com/cloudskiff/driftctl/pkg/remote/error"
resourceaws "github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/aws/aws-sdk-go/service/sqs"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/cloudskiff/driftctl/mocks"
"github.com/cloudskiff/driftctl/pkg/parallel"
"github.com/cloudskiff/driftctl/pkg/remote/deserializer"
@ -29,49 +27,33 @@ func TestSqsQueueSupplier_Resources(t *testing.T) {
cases := []struct {
test string
dirName string
mocks func(client *mocks.FakeSQS)
mocks func(client *mocks.SQSRepository)
err error
}{
{
test: "no sqs queues",
dirName: "sqs_queue_empty",
mocks: func(client *mocks.FakeSQS) {
client.On("ListQueuesPages",
&sqs.ListQueuesInput{},
mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool {
callback(&sqs.ListQueuesOutput{}, true)
return true
})).Return(nil)
mocks: func(client *mocks.SQSRepository) {
client.On("ListAllQueues").Return([]*string{}, nil)
},
err: nil,
},
{
test: "multiple sqs queues",
dirName: "sqs_queue_multiple",
mocks: func(client *mocks.FakeSQS) {
client.On("ListQueuesPages",
&sqs.ListQueuesInput{},
mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool {
callback(&sqs.ListQueuesOutput{
QueueUrls: []*string{
mocks: func(client *mocks.SQSRepository) {
client.On("ListAllQueues").Return([]*string{
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"),
awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"),
},
}, true)
return true
})).Return(nil)
}, nil)
},
err: nil,
},
{
test: "cannot list sqs queues",
dirName: "sqs_queue_empty",
mocks: func(client *mocks.FakeSQS) {
client.On(
"ListQueuesPages",
&sqs.ListQueuesInput{},
mock.Anything,
).Once().Return(awserr.NewRequestFailure(nil, 403, ""))
mocks: func(client *mocks.SQSRepository) {
client.On("ListAllQueues").Return(nil, awserr.NewRequestFailure(nil, 403, ""))
},
err: remoteerror.NewResourceEnumerationError(awserr.NewRequestFailure(nil, 403, ""), resourceaws.AwsSqsQueueResourceType),
},
@ -92,7 +74,7 @@ func TestSqsQueueSupplier_Resources(t *testing.T) {
}
t.Run(c.test, func(tt *testing.T) {
fakeSQS := mocks.FakeSQS{}
fakeSQS := mocks.SQSRepository{}
c.mocks(&fakeSQS)
provider := mocks2.NewMockedGoldenTFProvider(c.dirName, providerLibrary.Provider(terraform.AWS), shouldUpdate)
sqsQueueDeserializer := awsdeserializer.NewSqsQueueDeserializer()

View File

@ -3,10 +3,19 @@ package aws_test
import (
"testing"
"github.com/cloudskiff/driftctl/pkg/analyser"
awsresources "github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/r3labs/diff/v2"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/cloudskiff/driftctl/test/acceptance"
"github.com/cloudskiff/driftctl/test/acceptance/awsutils"
)
func TestAcc_AwsSqsQueue(t *testing.T) {
var mutatedQueue string
acceptance.Run(t, acceptance.AccTestCase{
Path: "./testdata/acc/aws_sqs_queue",
Args: []string{"scan", "--filter", "Type=='aws_sqs_queue'"},
@ -21,6 +30,42 @@ func TestAcc_AwsSqsQueue(t *testing.T) {
}
result.AssertInfrastructureIsInSync()
result.Equal(2, result.Summary().TotalManaged)
mutatedQueue = result.Managed()[0].TerraformId()
},
},
{
Env: map[string]string{
"AWS_REGION": "us-east-1",
},
PreExec: func() {
client := sqs.New(awsutils.Session())
attributes := make(map[string]*string)
attributes["DelaySeconds"] = aws.String("200")
_, err := client.SetQueueAttributes(&sqs.SetQueueAttributesInput{
Attributes: attributes,
QueueUrl: aws.String(mutatedQueue),
})
if err != nil {
t.Fatal(err)
}
},
Check: func(result *acceptance.ScanResult, stdout string, err error) {
if err != nil {
t.Fatal(err)
}
result.AssertDriftCountTotal(1)
result.AssertResourceHasDrift(
mutatedQueue,
awsresources.AwsSqsQueueResourceType,
analyser.Change{
Change: diff.Change{
Type: diff.UPDATE,
Path: []string{"DelaySeconds"},
From: float64(0),
To: float64(200),
},
},
)
},
},
},