From 2b71c8e6506e8e6869a2edfe38a12162e60561f8 Mon Sep 17 00:00:00 2001 From: Elie Date: Mon, 29 Mar 2021 18:10:50 +0200 Subject: [PATCH] Add resource factory --- pkg/cmd/scan.go | 5 +- pkg/driftctl.go | 34 +++--- pkg/driftctl_test.go | 10 +- pkg/middlewares/aws_bucket_policy_expander.go | 24 +++- .../aws_bucket_policy_expander_test.go | 10 +- pkg/middlewares/aws_instance_block_device.go | 43 ++++++- .../aws_instance_block_device_test.go | 16 ++- pkg/middlewares/aws_route_table_expander.go | 56 ++++++++- .../aws_route_table_expander_test.go | 49 +++++--- .../aws_sns_topic_policy_expander.go | 24 +++- .../aws_sns_topic_policy_expander_test.go | 14 ++- .../aws_sqs_queue_policy_expander.go | 23 +++- .../aws_sqs_queue_policy_expander_test.go | 10 +- pkg/middlewares/vpc_security_group_rules.go | 115 ++++++++++-------- .../vpc_security_group_rules_test.go | 8 +- pkg/resource/resource.go | 4 + pkg/terraform/mock_ResourceFactory.go | 36 ++++++ pkg/terraform/providers.go | 22 ++++ pkg/terraform/resource_factory.go | 46 +++++++ 19 files changed, 439 insertions(+), 110 deletions(-) create mode 100644 pkg/terraform/mock_ResourceFactory.go create mode 100644 pkg/terraform/resource_factory.go diff --git a/pkg/cmd/scan.go b/pkg/cmd/scan.go index caf84d1a..c8360d6d 100644 --- a/pkg/cmd/scan.go +++ b/pkg/cmd/scan.go @@ -171,7 +171,10 @@ func scanRun(opts *ScanOptions) error { if err != nil { return err } - ctl := pkg.NewDriftCTL(scanner, iacSupplier, opts.Filter, alerter) + + resFactory := terraform.NewTerraformResourceFactory(providerLibrary) + + ctl := pkg.NewDriftCTL(scanner, iacSupplier, opts.Filter, alerter, resFactory) go func() { <-c diff --git a/pkg/driftctl.go b/pkg/driftctl.go index 7b46cbc4..01f307b8 100644 --- a/pkg/driftctl.go +++ b/pkg/driftctl.go @@ -3,26 +3,28 @@ package pkg import ( "fmt" + "github.com/jmespath/go-jmespath" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/cloudskiff/driftctl/pkg/alerter" "github.com/cloudskiff/driftctl/pkg/analyser" "github.com/cloudskiff/driftctl/pkg/filter" "github.com/cloudskiff/driftctl/pkg/middlewares" "github.com/cloudskiff/driftctl/pkg/resource" - "github.com/jmespath/go-jmespath" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) type DriftCTL struct { - remoteSupplier resource.Supplier - iacSupplier resource.Supplier - alerter alerter.AlerterInterface - analyzer analyser.Analyzer - filter *jmespath.JMESPath + remoteSupplier resource.Supplier + iacSupplier resource.Supplier + alerter alerter.AlerterInterface + analyzer analyser.Analyzer + filter *jmespath.JMESPath + resourceFactory resource.ResourceFactory } -func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, filter *jmespath.JMESPath, alerter *alerter.Alerter) *DriftCTL { - return &DriftCTL{remoteSupplier, iacSupplier, alerter, analyser.NewAnalyzer(alerter), filter} +func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, filter *jmespath.JMESPath, alerter *alerter.Alerter, resFactory resource.ResourceFactory) *DriftCTL { + return &DriftCTL{remoteSupplier, iacSupplier, alerter, analyser.NewAnalyzer(alerter), filter, resFactory} } func (d DriftCTL) Run() (*analyser.Analysis, error) { @@ -34,23 +36,23 @@ func (d DriftCTL) Run() (*analyser.Analysis, error) { middleware := middlewares.NewChain( middlewares.NewRoute53DefaultZoneRecordSanitizer(), middlewares.NewS3BucketAcl(), - middlewares.NewAwsInstanceBlockDeviceResourceMapper(), + middlewares.NewAwsInstanceBlockDeviceResourceMapper(d.resourceFactory), middlewares.NewVPCDefaultSecurityGroupSanitizer(), - middlewares.NewVPCSecurityGroupRuleSanitizer(), + middlewares.NewVPCSecurityGroupRuleSanitizer(d.resourceFactory), middlewares.NewIamPolicyAttachmentSanitizer(), middlewares.AwsInstanceEIP{}, middlewares.NewAwsDefaultInternetGatewayRoute(), middlewares.NewAwsDefaultInternetGateway(), middlewares.NewAwsDefaultVPC(), middlewares.NewAwsDefaultSubnet(), - middlewares.NewAwsRouteTableExpander(d.alerter), + middlewares.NewAwsRouteTableExpander(d.alerter, d.resourceFactory), middlewares.NewAwsDefaultRouteTable(), middlewares.NewAwsDefaultRoute(), middlewares.NewAwsNatGatewayEipAssoc(), - middlewares.NewAwsBucketPolicyExpander(), - middlewares.NewAwsSqsQueuePolicyExpander(), + middlewares.NewAwsBucketPolicyExpander(d.resourceFactory), + middlewares.NewAwsSqsQueuePolicyExpander(d.resourceFactory), middlewares.NewAwsDefaultSqsQueuePolicy(), - middlewares.NewAwsSNSTopicPolicyExpander(), + middlewares.NewAwsSNSTopicPolicyExpander(d.resourceFactory), ) logrus.Debug("Ready to run middlewares") diff --git a/pkg/driftctl_test.go b/pkg/driftctl_test.go index e4979911..5e13e14a 100644 --- a/pkg/driftctl_test.go +++ b/pkg/driftctl_test.go @@ -12,6 +12,7 @@ import ( "github.com/cloudskiff/driftctl/pkg/analyser" filter2 "github.com/cloudskiff/driftctl/pkg/filter" "github.com/cloudskiff/driftctl/pkg/resource" + "github.com/cloudskiff/driftctl/pkg/terraform" "github.com/cloudskiff/driftctl/test" testresource "github.com/cloudskiff/driftctl/test/resource" ) @@ -21,6 +22,7 @@ type TestCase struct { stateResources []resource.Resource remoteResources []resource.Resource filter string + mocks func(factory resource.ResourceFactory) assert func(result *test.ScanResult, err error) } @@ -52,7 +54,13 @@ func runTest(t *testing.T, cases TestCases) { filter = f } - driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, filter, testAlerter) + resourceFactory := &terraform.MockResourceFactory{} + + if c.mocks != nil { + c.mocks(resourceFactory) + } + + driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, filter, testAlerter, resourceFactory) analysis, err := driftctl.Run() diff --git a/pkg/middlewares/aws_bucket_policy_expander.go b/pkg/middlewares/aws_bucket_policy_expander.go index 657b4c2f..6e9d1193 100644 --- a/pkg/middlewares/aws_bucket_policy_expander.go +++ b/pkg/middlewares/aws_bucket_policy_expander.go @@ -1,16 +1,21 @@ package middlewares import ( + "github.com/sirupsen/logrus" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" - "github.com/sirupsen/logrus" ) // Explodes policy found in aws_s3_bucket.policy from state resources to dedicated resources -type AwsBucketPolicyExpander struct{} +type AwsBucketPolicyExpander struct { + resourceFactory resource.ResourceFactory +} -func NewAwsBucketPolicyExpander() AwsBucketPolicyExpander { - return AwsBucketPolicyExpander{} +func NewAwsBucketPolicyExpander(resourceFactory resource.ResourceFactory) AwsBucketPolicyExpander { + return AwsBucketPolicyExpander{ + resourceFactory: resourceFactory, + } } func (m AwsBucketPolicyExpander) Execute(_, resourcesFromState *[]resource.Resource) error { @@ -44,10 +49,21 @@ func (m *AwsBucketPolicyExpander) handlePolicy(bucket *aws.AwsS3Bucket, results return nil } + data := map[string]interface{}{ + "id": bucket.Id, + "bucket": bucket.Bucket, + "policy": bucket.Policy, + } + ctyVal, err := m.resourceFactory.CreateResource(data, "aws_s3_bucket_policy") + if err != nil { + return err + } + newPolicy := &aws.AwsS3BucketPolicy{ Id: bucket.Id, Bucket: bucket.Bucket, Policy: bucket.Policy, + CtyVal: ctyVal, } normalizedRes, err := newPolicy.NormalizeForState() if err != nil { diff --git a/pkg/middlewares/aws_bucket_policy_expander_test.go b/pkg/middlewares/aws_bucket_policy_expander_test.go index 5676e511..36a5657e 100644 --- a/pkg/middlewares/aws_bucket_policy_expander_test.go +++ b/pkg/middlewares/aws_bucket_policy_expander_test.go @@ -6,8 +6,12 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/stretchr/testify/mock" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" + "github.com/cloudskiff/driftctl/pkg/terraform" + "github.com/r3labs/diff/v2" ) @@ -96,7 +100,11 @@ func TestAwsBucketPolicyExpander_Execute(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := NewAwsBucketPolicyExpander() + + factory := &terraform.MockResourceFactory{} + factory.On("CreateResource", mock.Anything, "aws_s3_bucket_policy").Once().Return(nil, nil) + + m := NewAwsBucketPolicyExpander(factory) err := m.Execute(&[]resource.Resource{}, &tt.resourcesFromState) if err != nil { t.Fatal(err) diff --git a/pkg/middlewares/aws_instance_block_device.go b/pkg/middlewares/aws_instance_block_device.go index 98fae645..82dd6a35 100644 --- a/pkg/middlewares/aws_instance_block_device.go +++ b/pkg/middlewares/aws_instance_block_device.go @@ -2,16 +2,19 @@ package middlewares import ( awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/sirupsen/logrus" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" - "github.com/sirupsen/logrus" ) // Remove root_block_device from aws_instance resources and create dedicated aws_ebs_volume resources -type AwsInstanceBlockDeviceResourceMapper struct{} +type AwsInstanceBlockDeviceResourceMapper struct { + resourceFactory resource.ResourceFactory +} -func NewAwsInstanceBlockDeviceResourceMapper() AwsInstanceBlockDeviceResourceMapper { - return AwsInstanceBlockDeviceResourceMapper{} +func NewAwsInstanceBlockDeviceResourceMapper(resourceFactory resource.ResourceFactory) AwsInstanceBlockDeviceResourceMapper { + return AwsInstanceBlockDeviceResourceMapper{resourceFactory: resourceFactory} } func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resourcesFromState *[]resource.Resource) error { @@ -32,6 +35,21 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources "volume": *rootBlock.VolumeId, "instance": instance.TerraformId(), }).Debug("Creating aws_ebs_volume from aws_instance.root_block_device") + data := map[string]interface{}{ + "availability_zone": instance.AvailabilityZone, + "encrypted": rootBlock.Encrypted, + "id": *rootBlock.VolumeId, + "iops": rootBlock.Iops, + "kms_key_id": rootBlock.KmsKeyId, + "size": rootBlock.VolumeSize, + "type": rootBlock.VolumeType, + "multi_attach_enabled": false, + "tags": instance.VolumeTags, + } + ctyVal, err := a.resourceFactory.CreateResource(data, "aws_ebs_volume") + if err != nil { + return err + } ebsVolume := aws.AwsEbsVolume{ AvailabilityZone: instance.AvailabilityZone, Encrypted: rootBlock.Encrypted, @@ -42,6 +60,7 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources Type: rootBlock.VolumeType, MultiAttachEnabled: awssdk.Bool(false), Tags: instance.VolumeTags, + CtyVal: ctyVal, } newStateResources = append(newStateResources, &ebsVolume) } @@ -53,6 +72,21 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources "volume": *blockDevice.VolumeId, "instance": instance.TerraformId(), }).Debug("Creating aws_ebs_volume from aws_instance.ebs_block_device") + data := map[string]interface{}{ + "availability_zone": instance.AvailabilityZone, + "encrypted": blockDevice.Encrypted, + "id": *blockDevice.VolumeId, + "iops": blockDevice.Iops, + "kms_key_id": blockDevice.KmsKeyId, + "size": blockDevice.VolumeSize, + "type": blockDevice.VolumeType, + "multi_attach_enabled": false, + "tags": instance.VolumeTags, + } + ctyVal, err := a.resourceFactory.CreateResource(data, "aws_ebs_volume") + if err != nil { + return err + } ebsVolume := aws.AwsEbsVolume{ AvailabilityZone: instance.AvailabilityZone, Encrypted: blockDevice.Encrypted, @@ -63,6 +97,7 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources Type: blockDevice.VolumeType, MultiAttachEnabled: awssdk.Bool(false), Tags: instance.VolumeTags, + CtyVal: ctyVal, } newStateResources = append(newStateResources, &ebsVolume) } diff --git a/pkg/middlewares/aws_instance_block_device_test.go b/pkg/middlewares/aws_instance_block_device_test.go index 158826b0..17ae55fa 100644 --- a/pkg/middlewares/aws_instance_block_device_test.go +++ b/pkg/middlewares/aws_instance_block_device_test.go @@ -6,8 +6,12 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/stretchr/testify/mock" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" + "github.com/cloudskiff/driftctl/pkg/terraform" + "github.com/r3labs/diff/v2" ) @@ -19,6 +23,7 @@ func TestAwsInstanceBlockDeviceResourceMapper_Execute(t *testing.T) { tests := []struct { name string args args + mocks func(factory *terraform.MockResourceFactory) wantErr bool }{ { @@ -118,12 +123,21 @@ func TestAwsInstanceBlockDeviceResourceMapper_Execute(t *testing.T) { }, }, }, + func(factory *terraform.MockResourceFactory) { + factory.On("CreateResource", mock.Anything, "aws_ebs_volume").Times(2).Return(nil, nil) + }, false, }, } for _, c := range tests { t.Run(c.name, func(tt *testing.T) { - a := AwsInstanceBlockDeviceResourceMapper{} + + factory := &terraform.MockResourceFactory{} + if c.mocks != nil { + c.mocks(factory) + } + + a := NewAwsInstanceBlockDeviceResourceMapper(factory) if err := a.Execute(&[]resource.Resource{}, c.args.resourcesFromState); (err != nil) != c.wantErr { t.Errorf("Execute() error = %v, wantErr %v", err, c.wantErr) } diff --git a/pkg/middlewares/aws_route_table_expander.go b/pkg/middlewares/aws_route_table_expander.go index 47037a71..7442778c 100644 --- a/pkg/middlewares/aws_route_table_expander.go +++ b/pkg/middlewares/aws_route_table_expander.go @@ -4,10 +4,11 @@ import ( "fmt" awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/sirupsen/logrus" + "github.com/cloudskiff/driftctl/pkg/alerter" "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" - "github.com/sirupsen/logrus" ) type invalidRouteAlert struct { @@ -29,12 +30,14 @@ func (i *invalidRouteAlert) ShouldIgnoreResource() bool { // Explodes routes found in aws_default_route_table.route and aws_route_table.route to dedicated resources type AwsRouteTableExpander struct { - alerter alerter.AlerterInterface + alerter alerter.AlerterInterface + resourceFactory resource.ResourceFactory } -func NewAwsRouteTableExpander(alerter alerter.AlerterInterface) AwsRouteTableExpander { +func NewAwsRouteTableExpander(alerter alerter.AlerterInterface, resourceFactory resource.ResourceFactory) AwsRouteTableExpander { return AwsRouteTableExpander{ alerter, + resourceFactory, } } @@ -80,6 +83,29 @@ func (m *AwsRouteTableExpander) handleTable(table *aws.AwsRouteTable, results *[ m.alerter.SendAlert(aws.AwsRouteTableResourceType, newInvalidRouteAlert(aws.AwsRouteTableResourceType, table.Id)) continue } + data := map[string]interface{}{ + "destination_cidr_block": route.CidrBlock, + "destination_ipv6_cidr_block": route.Ipv6CidrBlock, + "destination_prefix_list_id": "", + "egress_only_gateway_id": route.EgressOnlyGatewayId, + "gateway_id": route.GatewayId, + "id": routeId, + "instance_id": route.InstanceId, + "instance_owner_id": "", + "local_gateway_id": route.LocalGatewayId, + "nat_gateway_id": route.NatGatewayId, + "network_interface_id": route.NetworkInterfaceId, + "origin": "CreateRoute", + "route_table_id": table.Id, + "state": "active", + "transit_gateway_id": route.TransitGatewayId, + "vpc_endpoint_id": route.VpcEndpointId, + "vpc_peering_connection_id": route.VpcPeeringConnectionId, + } + ctyVal, err := m.resourceFactory.CreateResource(data, "aws_route") + if err != nil { + return err + } newRouteFromTable := &aws.AwsRoute{ DestinationCidrBlock: route.CidrBlock, DestinationIpv6CidrBlock: route.Ipv6CidrBlock, @@ -98,6 +124,7 @@ func (m *AwsRouteTableExpander) handleTable(table *aws.AwsRouteTable, results *[ TransitGatewayId: route.TransitGatewayId, VpcEndpointId: route.VpcEndpointId, VpcPeeringConnectionId: route.VpcPeeringConnectionId, + CtyVal: ctyVal, } normalizedRes, err := newRouteFromTable.NormalizeForState() if err != nil { @@ -125,6 +152,28 @@ func (m *AwsRouteTableExpander) handleDefaultTable(table *aws.AwsDefaultRouteTab m.alerter.SendAlert(aws.AwsDefaultRouteTableResourceType, newInvalidRouteAlert(aws.AwsDefaultRouteTableResourceType, table.Id)) continue } + data := map[string]interface{}{ + "destination_cidr_block": route.CidrBlock, + "destination_ipv6_cidr_block": route.Ipv6CidrBlock, + "destination_prefix_list_id": "", + "egress_only_gateway_id": route.EgressOnlyGatewayId, + "gateway_id": route.GatewayId, + "id": routeId, + "instance_id": route.InstanceId, + "instance_owner_id": "", + "nat_gateway_id": route.NatGatewayId, + "network_interface_id": route.NetworkInterfaceId, + "origin": "CreateRoute", + "route_table_id": table.Id, + "state": "active", + "transit_gateway_id": route.TransitGatewayId, + "vpc_endpoint_id": route.VpcEndpointId, + "vpc_peering_connection_id": route.VpcPeeringConnectionId, + } + ctyVal, err := m.resourceFactory.CreateResource(data, "aws_route") + if err != nil { + return err + } newRouteFromTable := &aws.AwsRoute{ DestinationCidrBlock: route.CidrBlock, DestinationIpv6CidrBlock: route.Ipv6CidrBlock, @@ -142,6 +191,7 @@ func (m *AwsRouteTableExpander) handleDefaultTable(table *aws.AwsDefaultRouteTab TransitGatewayId: route.TransitGatewayId, VpcEndpointId: route.VpcEndpointId, VpcPeeringConnectionId: route.VpcPeeringConnectionId, + CtyVal: ctyVal, } normalizedRes, err := newRouteFromTable.NormalizeForState() if err != nil { diff --git a/pkg/middlewares/aws_route_table_expander_test.go b/pkg/middlewares/aws_route_table_expander_test.go index 4b5541c7..2c58afaf 100644 --- a/pkg/middlewares/aws_route_table_expander_test.go +++ b/pkg/middlewares/aws_route_table_expander_test.go @@ -6,11 +6,14 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/mock" + "github.com/cloudskiff/driftctl/mocks" "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" + "github.com/cloudskiff/driftctl/pkg/terraform" resource2 "github.com/cloudskiff/driftctl/test/resource" - "github.com/r3labs/diff/v2" ) func TestAwsRouteTableExpander_Execute(t *testing.T) { @@ -18,17 +21,17 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { name string input []resource.Resource expected []resource.Resource + mock func(factory *terraform.MockResourceFactory) }{ { - "test with nil route attributes", - []resource.Resource{ + name: "test with nil route attributes", + input: []resource.Resource{ &aws.AwsRouteTable{ Id: "table_from_state", Route: nil, }, }, - - []resource.Resource{ + expected: []resource.Resource{ &aws.AwsRouteTable{ Id: "table_from_state", Route: nil, @@ -36,8 +39,8 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { }, }, { - "test with empty route attributes", - []resource.Resource{ + name: "test with empty route attributes", + input: []resource.Resource{ &aws.AwsRouteTable{ Id: "table_from_state", Route: &[]struct { @@ -55,7 +58,7 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { }{}, }, }, - []resource.Resource{ + expected: []resource.Resource{ &aws.AwsRouteTable{ Id: "table_from_state", Route: nil, @@ -63,8 +66,8 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { }, }, { - "test route are expanded", - []resource.Resource{ + name: "test route are expanded", + input: []resource.Resource{ &resource2.FakeResource{ Id: "fake_resource", }, @@ -95,7 +98,7 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { }, }, }, - []resource.Resource{ + expected: []resource.Resource{ &resource2.FakeResource{ Id: "fake_resource", }, @@ -124,10 +127,13 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { InstanceOwnerId: awssdk.String(""), }, }, + mock: func(factory *terraform.MockResourceFactory) { + factory.On("CreateResource", mock.Anything, "aws_route").Times(2).Return(nil, nil) + }, }, { - "test route are expanded on default route tables", - []resource.Resource{ + name: "test route are expanded on default route tables", + input: []resource.Resource{ &resource2.FakeResource{ Id: "fake_resource", }, @@ -157,7 +163,7 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { }, }, }, - []resource.Resource{ + expected: []resource.Resource{ &resource2.FakeResource{ Id: "fake_resource", }, @@ -186,12 +192,21 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) { InstanceOwnerId: awssdk.String(""), }, }, + mock: func(factory *terraform.MockResourceFactory) { + factory.On("CreateResource", mock.Anything, "aws_route").Times(2).Return(nil, nil) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockedAlerter := &mocks.AlerterInterface{} - m := NewAwsRouteTableExpander(mockedAlerter) + + factory := &terraform.MockResourceFactory{} + if tt.mock != nil { + tt.mock(factory) + } + + m := NewAwsRouteTableExpander(mockedAlerter, factory) err := m.Execute(nil, &tt.input) if err != nil { t.Fatal(err) @@ -274,7 +289,9 @@ func TestAwsRouteTableExpander_ExecuteWithInvalidRoutes(t *testing.T) { }, } - m := NewAwsRouteTableExpander(mockedAlerter) + factory := &terraform.MockResourceFactory{} + + m := NewAwsRouteTableExpander(mockedAlerter, factory) err := m.Execute(nil, &input) if err != nil { t.Fatal(err) diff --git a/pkg/middlewares/aws_sns_topic_policy_expander.go b/pkg/middlewares/aws_sns_topic_policy_expander.go index d2d2dd8f..80502490 100644 --- a/pkg/middlewares/aws_sns_topic_policy_expander.go +++ b/pkg/middlewares/aws_sns_topic_policy_expander.go @@ -1,16 +1,21 @@ package middlewares import ( + "github.com/sirupsen/logrus" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" - "github.com/sirupsen/logrus" ) // Explodes policy found in aws_sns_topic from state resources to aws_sns_topic_policy resources -type AwsSNSTopicPolicyExpander struct{} +type AwsSNSTopicPolicyExpander struct { + resourceFactory resource.ResourceFactory +} -func NewAwsSNSTopicPolicyExpander() AwsSNSTopicPolicyExpander { - return AwsSNSTopicPolicyExpander{} +func NewAwsSNSTopicPolicyExpander(resourceFactory resource.ResourceFactory) AwsSNSTopicPolicyExpander { + return AwsSNSTopicPolicyExpander{ + resourceFactory, + } } func (m AwsSNSTopicPolicyExpander) Execute(_, resourcesFromState *[]resource.Resource) error { @@ -44,10 +49,21 @@ func (m *AwsSNSTopicPolicyExpander) splitPolicy(topic *aws.AwsSnsTopic, results return nil } + data := map[string]interface{}{ + "arn": topic.Arn, + "id": topic.Id, + "policy": topic.Policy, + } + ctyVal, err := m.resourceFactory.CreateResource(data, "aws_sns_topic_policy") + if err != nil { + return err + } + newPolicy := &aws.AwsSnsTopicPolicy{ Id: topic.Id, Arn: topic.Arn, Policy: topic.Policy, + CtyVal: ctyVal, } normalized, err := newPolicy.NormalizeForState() diff --git a/pkg/middlewares/aws_sns_topic_policy_expander_test.go b/pkg/middlewares/aws_sns_topic_policy_expander_test.go index 596fd1f0..250dea01 100644 --- a/pkg/middlewares/aws_sns_topic_policy_expander_test.go +++ b/pkg/middlewares/aws_sns_topic_policy_expander_test.go @@ -5,8 +5,10 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/mock" awsresource "github.com/cloudskiff/driftctl/pkg/resource/aws" + "github.com/cloudskiff/driftctl/pkg/terraform" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" @@ -19,6 +21,7 @@ func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) { name string resourcesFromState *[]resource.Resource expected *[]resource.Resource + mock func(factory *terraform.MockResourceFactory) wantErr bool }{ { @@ -42,6 +45,9 @@ func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) { Id: "ID", }, }, + mock: func(factory *terraform.MockResourceFactory) { + factory.On("CreateResource", mock.Anything, "aws_sns_topic_policy").Once().Return(nil, nil) + }, wantErr: false, }, { @@ -103,7 +109,13 @@ func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := AwsSNSTopicPolicyExpander{} + + factory := &terraform.MockResourceFactory{} + if tt.mock != nil { + tt.mock(factory) + } + + m := NewAwsSNSTopicPolicyExpander(factory) if err := m.Execute(&[]resource.Resource{}, tt.resourcesFromState); (err != nil) != tt.wantErr { t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/pkg/middlewares/aws_sqs_queue_policy_expander.go b/pkg/middlewares/aws_sqs_queue_policy_expander.go index 5c1398ba..609d04fe 100644 --- a/pkg/middlewares/aws_sqs_queue_policy_expander.go +++ b/pkg/middlewares/aws_sqs_queue_policy_expander.go @@ -2,16 +2,21 @@ package middlewares import ( awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/sirupsen/logrus" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" - "github.com/sirupsen/logrus" ) // Explodes policy found in aws_sqs_queue.policy from state resources to dedicated resources -type AwsSqsQueuePolicyExpander struct{} +type AwsSqsQueuePolicyExpander struct { + resourceFactory resource.ResourceFactory +} -func NewAwsSqsQueuePolicyExpander() AwsSqsQueuePolicyExpander { - return AwsSqsQueuePolicyExpander{} +func NewAwsSqsQueuePolicyExpander(resourceFactory resource.ResourceFactory) AwsSqsQueuePolicyExpander { + return AwsSqsQueuePolicyExpander{ + resourceFactory, + } } func (m AwsSqsQueuePolicyExpander) Execute(_, resourcesFromState *[]resource.Resource) error { @@ -45,10 +50,20 @@ func (m AwsSqsQueuePolicyExpander) Execute(_, resourcesFromState *[]resource.Res } func (m *AwsSqsQueuePolicyExpander) handlePolicy(queue *aws.AwsSqsQueue, results *[]resource.Resource) error { + data := map[string]interface{}{ + "queue_url": queue.Id, + "id": queue.Id, + "policy": queue.Policy, + } + ctyVal, err := m.resourceFactory.CreateResource(data, "aws_sqs_queue_policy") + if err != nil { + return err + } newPolicy := &aws.AwsSqsQueuePolicy{ Id: queue.Id, QueueUrl: awssdk.String(queue.Id), Policy: queue.Policy, + CtyVal: ctyVal, } normalizedRes, err := newPolicy.NormalizeForState() if err != nil { diff --git a/pkg/middlewares/aws_sqs_queue_policy_expander_test.go b/pkg/middlewares/aws_sqs_queue_policy_expander_test.go index 06849a63..276b97d2 100644 --- a/pkg/middlewares/aws_sqs_queue_policy_expander_test.go +++ b/pkg/middlewares/aws_sqs_queue_policy_expander_test.go @@ -6,8 +6,12 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" + "github.com/stretchr/testify/mock" + "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" + "github.com/cloudskiff/driftctl/pkg/terraform" + "github.com/r3labs/diff/v2" ) @@ -90,7 +94,11 @@ func TestAwsSqsQueuePolicyExpander_Execute(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := NewAwsSqsQueuePolicyExpander() + + factory := &terraform.MockResourceFactory{} + factory.On("CreateResource", mock.Anything, "aws_sqs_queue_policy").Once().Return(nil, nil) + + m := NewAwsSqsQueuePolicyExpander(factory) err := m.Execute(&[]resource.Resource{}, &tt.resourcesFromState) if err != nil { t.Fatal(err) diff --git a/pkg/middlewares/vpc_security_group_rules.go b/pkg/middlewares/vpc_security_group_rules.go index ea190d46..14d4b2a0 100644 --- a/pkg/middlewares/vpc_security_group_rules.go +++ b/pkg/middlewares/vpc_security_group_rules.go @@ -8,10 +8,14 @@ import ( ) // Split security group rule if it needs to given its attributes -type VPCSecurityGroupRuleSanitizer struct{} +type VPCSecurityGroupRuleSanitizer struct { + resourceFactory resource.ResourceFactory +} -func NewVPCSecurityGroupRuleSanitizer() VPCSecurityGroupRuleSanitizer { - return VPCSecurityGroupRuleSanitizer{} +func NewVPCSecurityGroupRuleSanitizer(resourceFactory resource.ResourceFactory) VPCSecurityGroupRuleSanitizer { + return VPCSecurityGroupRuleSanitizer{ + resourceFactory, + } } func (m VPCSecurityGroupRuleSanitizer) Execute(_, resourcesFromState *[]resource.Resource) error { @@ -33,88 +37,71 @@ func (m VPCSecurityGroupRuleSanitizer) Execute(_, resourcesFromState *[]resource if securityGroupRule.CidrBlocks != nil && len(*securityGroupRule.CidrBlocks) > 0 { for _, ipRange := range *securityGroupRule.CidrBlocks { - rule := resourceaws.AwsSecurityGroupRule{ - Type: securityGroupRule.Type, - Description: securityGroupRule.Description, - SecurityGroupId: securityGroupRule.SecurityGroupId, - Protocol: securityGroupRule.Protocol, - FromPort: securityGroupRule.FromPort, - ToPort: securityGroupRule.ToPort, - CidrBlocks: &[]string{ipRange}, - Ipv6CidrBlocks: &[]string{}, - PrefixListIds: &[]string{}, + rule := *securityGroupRule + rule.CidrBlocks = &[]string{ipRange} + rule.Ipv6CidrBlocks = &[]string{} + rule.PrefixListIds = &[]string{} + res, err := m.createRule(&rule) + if err != nil { + return err } - rule.Id = rule.CreateIdHash() logrus.WithFields(logrus.Fields{ "formerRuleId": securityGroupRule.TerraformId(), "newRuleId": rule.TerraformId(), }).Debug("Splitting aws_security_group_rule") - newStateResources = append(newStateResources, &rule) + newStateResources = append(newStateResources, res) } } if securityGroupRule.Ipv6CidrBlocks != nil && len(*securityGroupRule.Ipv6CidrBlocks) > 0 { for _, ipRange := range *securityGroupRule.Ipv6CidrBlocks { - rule := resourceaws.AwsSecurityGroupRule{ - Type: securityGroupRule.Type, - Description: securityGroupRule.Description, - SecurityGroupId: securityGroupRule.SecurityGroupId, - Protocol: securityGroupRule.Protocol, - FromPort: securityGroupRule.FromPort, - ToPort: securityGroupRule.ToPort, - CidrBlocks: &[]string{}, - Ipv6CidrBlocks: &[]string{ipRange}, - PrefixListIds: &[]string{}, + rule := *securityGroupRule + rule.CidrBlocks = &[]string{} + rule.Ipv6CidrBlocks = &[]string{ipRange} + rule.PrefixListIds = &[]string{} + res, err := m.createRule(&rule) + if err != nil { + return err } - rule.Id = rule.CreateIdHash() logrus.WithFields(logrus.Fields{ "formerRuleId": securityGroupRule.TerraformId(), "newRuleId": rule.TerraformId(), }).Debug("Splitting aws_security_group_rule") - newStateResources = append(newStateResources, &rule) + newStateResources = append(newStateResources, res) } } if securityGroupRule.PrefixListIds != nil && len(*securityGroupRule.PrefixListIds) > 0 { for _, listId := range *securityGroupRule.PrefixListIds { - rule := resourceaws.AwsSecurityGroupRule{ - Type: securityGroupRule.Type, - Description: securityGroupRule.Description, - SecurityGroupId: securityGroupRule.SecurityGroupId, - Protocol: securityGroupRule.Protocol, - FromPort: securityGroupRule.FromPort, - ToPort: securityGroupRule.ToPort, - CidrBlocks: &[]string{}, - Ipv6CidrBlocks: &[]string{}, - PrefixListIds: &[]string{listId}, + rule := *securityGroupRule + rule.CidrBlocks = &[]string{} + rule.Ipv6CidrBlocks = &[]string{} + rule.PrefixListIds = &[]string{listId} + res, err := m.createRule(&rule) + if err != nil { + return err } - rule.Id = rule.CreateIdHash() logrus.WithFields(logrus.Fields{ "formerRuleId": securityGroupRule.TerraformId(), "newRuleId": rule.TerraformId(), }).Debug("Splitting aws_security_group_rule") - newStateResources = append(newStateResources, &rule) + newStateResources = append(newStateResources, res) } } if (securityGroupRule.Self != nil && *securityGroupRule.Self) || (securityGroupRule.SourceSecurityGroupId != nil && *securityGroupRule.SourceSecurityGroupId != "") { - rule := resourceaws.AwsSecurityGroupRule{ - Type: securityGroupRule.Type, - Description: securityGroupRule.Description, - SecurityGroupId: securityGroupRule.SecurityGroupId, - Protocol: securityGroupRule.Protocol, - FromPort: securityGroupRule.FromPort, - ToPort: securityGroupRule.ToPort, - CidrBlocks: &[]string{}, - Ipv6CidrBlocks: &[]string{}, - PrefixListIds: &[]string{}, - Self: securityGroupRule.Self, - SourceSecurityGroupId: securityGroupRule.SourceSecurityGroupId, + rule := *securityGroupRule + rule.CidrBlocks = &[]string{} + rule.Ipv6CidrBlocks = &[]string{} + rule.PrefixListIds = &[]string{} + res, err := m.createRule(&rule) + if err != nil { + return err } rule.Id = rule.CreateIdHash() logrus.WithFields(logrus.Fields{ "formerRuleId": securityGroupRule.TerraformId(), "newRuleId": rule.TerraformId(), }).Debug("Splitting aws_security_group_rule") - newStateResources = append(newStateResources, &rule) + newStateResources = append(newStateResources, res) } } @@ -123,6 +110,30 @@ func (m VPCSecurityGroupRuleSanitizer) Execute(_, resourcesFromState *[]resource return nil } +func (m *VPCSecurityGroupRuleSanitizer) createRule(res *resourceaws.AwsSecurityGroupRule) (*resourceaws.AwsSecurityGroupRule, error) { + res.Id = res.CreateIdHash() + data := map[string]interface{}{ + "id": res.Id, + "cidr_blocks": res.CidrBlocks, + "description": res.Description, + "from_port": res.FromPort, + "ipv6_cidr_blocks": res.Ipv6CidrBlocks, + "prefix_list_ids": res.PrefixListIds, + "protocol": res.Protocol, + "security_group_id": res.SecurityGroupId, + "self": res.Self, + "source_security_group_id": res.SourceSecurityGroupId, + "to_port": res.ToPort, + "type": res.Type, + } + ctyVal, err := m.resourceFactory.CreateResource(data, "aws_security_group_rule") + if err != nil { + return nil, err + } + res.CtyVal = ctyVal + return res, err +} + func shouldBeSplit(rule *resourceaws.AwsSecurityGroupRule) bool { var i int if rule.CidrBlocks != nil && len(*rule.CidrBlocks) > 0 { diff --git a/pkg/middlewares/vpc_security_group_rules_test.go b/pkg/middlewares/vpc_security_group_rules_test.go index de622db8..317b641b 100644 --- a/pkg/middlewares/vpc_security_group_rules_test.go +++ b/pkg/middlewares/vpc_security_group_rules_test.go @@ -4,13 +4,19 @@ import ( "testing" awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/mock" "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" + "github.com/cloudskiff/driftctl/pkg/terraform" ) func TestVPCSecurityGroupRuleSanitizer(t *testing.T) { - middleware := NewVPCSecurityGroupRuleSanitizer() + + factory := &terraform.MockResourceFactory{} + factory.On("CreateResource", mock.Anything, "aws_security_group_rule").Times(8).Return(nil, nil) + + middleware := NewVPCSecurityGroupRuleSanitizer(factory) var remoteResources []resource.Resource stateResources := []resource.Resource{ &aws.AwsSecurityGroup{ diff --git a/pkg/resource/resource.go b/pkg/resource/resource.go index f907eb8e..816a1d0f 100644 --- a/pkg/resource/resource.go +++ b/pkg/resource/resource.go @@ -13,6 +13,10 @@ type Resource interface { CtyValue() *cty.Value } +type ResourceFactory interface { + CreateResource(data interface{}, ty string) (*cty.Value, error) +} + type SerializableResource struct { Resource } diff --git a/pkg/terraform/mock_ResourceFactory.go b/pkg/terraform/mock_ResourceFactory.go new file mode 100644 index 00000000..89f9f43b --- /dev/null +++ b/pkg/terraform/mock_ResourceFactory.go @@ -0,0 +1,36 @@ +// Code generated by mockery v2.3.0. DO NOT EDIT. + +package terraform + +import ( + mock "github.com/stretchr/testify/mock" + cty "github.com/zclconf/go-cty/cty" +) + +// MockResourceFactory is an autogenerated mock type for the ResourceFactory type +type MockResourceFactory struct { + mock.Mock +} + +// CreateResource provides a mock function with given fields: data, ty +func (_m *MockResourceFactory) CreateResource(data interface{}, ty string) (*cty.Value, error) { + ret := _m.Called(data, ty) + + var r0 *cty.Value + if rf, ok := ret.Get(0).(func(interface{}, string) *cty.Value); ok { + r0 = rf(data, ty) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*cty.Value) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(interface{}, string) error); ok { + r1 = rf(data, ty) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/terraform/providers.go b/pkg/terraform/providers.go index 11b58358..881e6435 100644 --- a/pkg/terraform/providers.go +++ b/pkg/terraform/providers.go @@ -1,6 +1,9 @@ package terraform import ( + "strings" + + "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -37,3 +40,22 @@ func (p *ProviderLibrary) Cleanup() { provider.Cleanup() } } + +func (p *ProviderLibrary) GetProviderForResourceType(resType string) (TerraformProvider, error) { + + var name string + + if strings.HasPrefix(resType, AWS) { + name = AWS + } + + if strings.HasPrefix(resType, GITHUB) { + name = GITHUB + } + + if name != "" { + return p.Provider(name), nil + } + + return nil, errors.New("Unable to resolve provider for resource") +} diff --git a/pkg/terraform/resource_factory.go b/pkg/terraform/resource_factory.go new file mode 100644 index 00000000..2a9ddfbc --- /dev/null +++ b/pkg/terraform/resource_factory.go @@ -0,0 +1,46 @@ +package terraform + +import ( + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/gocty" +) + +type TerraformResourceFactory struct { + providerLibrary *ProviderLibrary +} + +func NewTerraformResourceFactory(providerLibrary *ProviderLibrary) *TerraformResourceFactory { + return &TerraformResourceFactory{providerLibrary: providerLibrary} +} + +func (r *TerraformResourceFactory) resolveType(ty string) (cty.Type, error) { + provider, err := r.providerLibrary.GetProviderForResourceType(ty) + if err != nil { + return cty.NilType, err + } + if schemas, exist := provider.Schema()[ty]; exist { + return schemas.Block.ImpliedType(), nil + } + + return cty.NilType, errors.New("Unable to find ") +} + +func (r *TerraformResourceFactory) CreateResource(data interface{}, ty string) (*cty.Value, error) { + ctyType, err := r.resolveType(ty) + if err != nil { + return nil, err + } + + logrus.WithFields(logrus.Fields{ + "type": ty, + }).Debug("Found cty type for resource") + + val, err := gocty.ToCtyValue(data, ctyType) + if err != nil { + return nil, err + } + + return &val, nil +}