Add resource factory

main
Elie 2021-03-29 18:10:50 +02:00
parent 5789a0c24e
commit 2b71c8e650
19 changed files with 439 additions and 110 deletions

View File

@ -171,7 +171,10 @@ func scanRun(opts *ScanOptions) error {
if err != nil {
return err
}
ctl := pkg.NewDriftCTL(scanner, iacSupplier, opts.Filter, alerter)
resFactory := terraform.NewTerraformResourceFactory(providerLibrary)
ctl := pkg.NewDriftCTL(scanner, iacSupplier, opts.Filter, alerter, resFactory)
go func() {
<-c

View File

@ -3,26 +3,28 @@ package pkg
import (
"fmt"
"github.com/jmespath/go-jmespath"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/cloudskiff/driftctl/pkg/alerter"
"github.com/cloudskiff/driftctl/pkg/analyser"
"github.com/cloudskiff/driftctl/pkg/filter"
"github.com/cloudskiff/driftctl/pkg/middlewares"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/jmespath/go-jmespath"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
type DriftCTL struct {
remoteSupplier resource.Supplier
iacSupplier resource.Supplier
alerter alerter.AlerterInterface
analyzer analyser.Analyzer
filter *jmespath.JMESPath
remoteSupplier resource.Supplier
iacSupplier resource.Supplier
alerter alerter.AlerterInterface
analyzer analyser.Analyzer
filter *jmespath.JMESPath
resourceFactory resource.ResourceFactory
}
func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, filter *jmespath.JMESPath, alerter *alerter.Alerter) *DriftCTL {
return &DriftCTL{remoteSupplier, iacSupplier, alerter, analyser.NewAnalyzer(alerter), filter}
func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, filter *jmespath.JMESPath, alerter *alerter.Alerter, resFactory resource.ResourceFactory) *DriftCTL {
return &DriftCTL{remoteSupplier, iacSupplier, alerter, analyser.NewAnalyzer(alerter), filter, resFactory}
}
func (d DriftCTL) Run() (*analyser.Analysis, error) {
@ -34,23 +36,23 @@ func (d DriftCTL) Run() (*analyser.Analysis, error) {
middleware := middlewares.NewChain(
middlewares.NewRoute53DefaultZoneRecordSanitizer(),
middlewares.NewS3BucketAcl(),
middlewares.NewAwsInstanceBlockDeviceResourceMapper(),
middlewares.NewAwsInstanceBlockDeviceResourceMapper(d.resourceFactory),
middlewares.NewVPCDefaultSecurityGroupSanitizer(),
middlewares.NewVPCSecurityGroupRuleSanitizer(),
middlewares.NewVPCSecurityGroupRuleSanitizer(d.resourceFactory),
middlewares.NewIamPolicyAttachmentSanitizer(),
middlewares.AwsInstanceEIP{},
middlewares.NewAwsDefaultInternetGatewayRoute(),
middlewares.NewAwsDefaultInternetGateway(),
middlewares.NewAwsDefaultVPC(),
middlewares.NewAwsDefaultSubnet(),
middlewares.NewAwsRouteTableExpander(d.alerter),
middlewares.NewAwsRouteTableExpander(d.alerter, d.resourceFactory),
middlewares.NewAwsDefaultRouteTable(),
middlewares.NewAwsDefaultRoute(),
middlewares.NewAwsNatGatewayEipAssoc(),
middlewares.NewAwsBucketPolicyExpander(),
middlewares.NewAwsSqsQueuePolicyExpander(),
middlewares.NewAwsBucketPolicyExpander(d.resourceFactory),
middlewares.NewAwsSqsQueuePolicyExpander(d.resourceFactory),
middlewares.NewAwsDefaultSqsQueuePolicy(),
middlewares.NewAwsSNSTopicPolicyExpander(),
middlewares.NewAwsSNSTopicPolicyExpander(d.resourceFactory),
)
logrus.Debug("Ready to run middlewares")

View File

@ -12,6 +12,7 @@ import (
"github.com/cloudskiff/driftctl/pkg/analyser"
filter2 "github.com/cloudskiff/driftctl/pkg/filter"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/cloudskiff/driftctl/test"
testresource "github.com/cloudskiff/driftctl/test/resource"
)
@ -21,6 +22,7 @@ type TestCase struct {
stateResources []resource.Resource
remoteResources []resource.Resource
filter string
mocks func(factory resource.ResourceFactory)
assert func(result *test.ScanResult, err error)
}
@ -52,7 +54,13 @@ func runTest(t *testing.T, cases TestCases) {
filter = f
}
driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, filter, testAlerter)
resourceFactory := &terraform.MockResourceFactory{}
if c.mocks != nil {
c.mocks(resourceFactory)
}
driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, filter, testAlerter, resourceFactory)
analysis, err := driftctl.Run()

View File

@ -1,16 +1,21 @@
package middlewares
import (
"github.com/sirupsen/logrus"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/sirupsen/logrus"
)
// Explodes policy found in aws_s3_bucket.policy from state resources to dedicated resources
type AwsBucketPolicyExpander struct{}
type AwsBucketPolicyExpander struct {
resourceFactory resource.ResourceFactory
}
func NewAwsBucketPolicyExpander() AwsBucketPolicyExpander {
return AwsBucketPolicyExpander{}
func NewAwsBucketPolicyExpander(resourceFactory resource.ResourceFactory) AwsBucketPolicyExpander {
return AwsBucketPolicyExpander{
resourceFactory: resourceFactory,
}
}
func (m AwsBucketPolicyExpander) Execute(_, resourcesFromState *[]resource.Resource) error {
@ -44,10 +49,21 @@ func (m *AwsBucketPolicyExpander) handlePolicy(bucket *aws.AwsS3Bucket, results
return nil
}
data := map[string]interface{}{
"id": bucket.Id,
"bucket": bucket.Bucket,
"policy": bucket.Policy,
}
ctyVal, err := m.resourceFactory.CreateResource(data, "aws_s3_bucket_policy")
if err != nil {
return err
}
newPolicy := &aws.AwsS3BucketPolicy{
Id: bucket.Id,
Bucket: bucket.Bucket,
Policy: bucket.Policy,
CtyVal: ctyVal,
}
normalizedRes, err := newPolicy.NormalizeForState()
if err != nil {

View File

@ -6,8 +6,12 @@ import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/mock"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/r3labs/diff/v2"
)
@ -96,7 +100,11 @@ func TestAwsBucketPolicyExpander_Execute(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := NewAwsBucketPolicyExpander()
factory := &terraform.MockResourceFactory{}
factory.On("CreateResource", mock.Anything, "aws_s3_bucket_policy").Once().Return(nil, nil)
m := NewAwsBucketPolicyExpander(factory)
err := m.Execute(&[]resource.Resource{}, &tt.resourcesFromState)
if err != nil {
t.Fatal(err)

View File

@ -2,16 +2,19 @@ package middlewares
import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/sirupsen/logrus"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/sirupsen/logrus"
)
// Remove root_block_device from aws_instance resources and create dedicated aws_ebs_volume resources
type AwsInstanceBlockDeviceResourceMapper struct{}
type AwsInstanceBlockDeviceResourceMapper struct {
resourceFactory resource.ResourceFactory
}
func NewAwsInstanceBlockDeviceResourceMapper() AwsInstanceBlockDeviceResourceMapper {
return AwsInstanceBlockDeviceResourceMapper{}
func NewAwsInstanceBlockDeviceResourceMapper(resourceFactory resource.ResourceFactory) AwsInstanceBlockDeviceResourceMapper {
return AwsInstanceBlockDeviceResourceMapper{resourceFactory: resourceFactory}
}
func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resourcesFromState *[]resource.Resource) error {
@ -32,6 +35,21 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources
"volume": *rootBlock.VolumeId,
"instance": instance.TerraformId(),
}).Debug("Creating aws_ebs_volume from aws_instance.root_block_device")
data := map[string]interface{}{
"availability_zone": instance.AvailabilityZone,
"encrypted": rootBlock.Encrypted,
"id": *rootBlock.VolumeId,
"iops": rootBlock.Iops,
"kms_key_id": rootBlock.KmsKeyId,
"size": rootBlock.VolumeSize,
"type": rootBlock.VolumeType,
"multi_attach_enabled": false,
"tags": instance.VolumeTags,
}
ctyVal, err := a.resourceFactory.CreateResource(data, "aws_ebs_volume")
if err != nil {
return err
}
ebsVolume := aws.AwsEbsVolume{
AvailabilityZone: instance.AvailabilityZone,
Encrypted: rootBlock.Encrypted,
@ -42,6 +60,7 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources
Type: rootBlock.VolumeType,
MultiAttachEnabled: awssdk.Bool(false),
Tags: instance.VolumeTags,
CtyVal: ctyVal,
}
newStateResources = append(newStateResources, &ebsVolume)
}
@ -53,6 +72,21 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources
"volume": *blockDevice.VolumeId,
"instance": instance.TerraformId(),
}).Debug("Creating aws_ebs_volume from aws_instance.ebs_block_device")
data := map[string]interface{}{
"availability_zone": instance.AvailabilityZone,
"encrypted": blockDevice.Encrypted,
"id": *blockDevice.VolumeId,
"iops": blockDevice.Iops,
"kms_key_id": blockDevice.KmsKeyId,
"size": blockDevice.VolumeSize,
"type": blockDevice.VolumeType,
"multi_attach_enabled": false,
"tags": instance.VolumeTags,
}
ctyVal, err := a.resourceFactory.CreateResource(data, "aws_ebs_volume")
if err != nil {
return err
}
ebsVolume := aws.AwsEbsVolume{
AvailabilityZone: instance.AvailabilityZone,
Encrypted: blockDevice.Encrypted,
@ -63,6 +97,7 @@ func (a AwsInstanceBlockDeviceResourceMapper) Execute(remoteResources, resources
Type: blockDevice.VolumeType,
MultiAttachEnabled: awssdk.Bool(false),
Tags: instance.VolumeTags,
CtyVal: ctyVal,
}
newStateResources = append(newStateResources, &ebsVolume)
}

View File

@ -6,8 +6,12 @@ import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/mock"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/r3labs/diff/v2"
)
@ -19,6 +23,7 @@ func TestAwsInstanceBlockDeviceResourceMapper_Execute(t *testing.T) {
tests := []struct {
name string
args args
mocks func(factory *terraform.MockResourceFactory)
wantErr bool
}{
{
@ -118,12 +123,21 @@ func TestAwsInstanceBlockDeviceResourceMapper_Execute(t *testing.T) {
},
},
},
func(factory *terraform.MockResourceFactory) {
factory.On("CreateResource", mock.Anything, "aws_ebs_volume").Times(2).Return(nil, nil)
},
false,
},
}
for _, c := range tests {
t.Run(c.name, func(tt *testing.T) {
a := AwsInstanceBlockDeviceResourceMapper{}
factory := &terraform.MockResourceFactory{}
if c.mocks != nil {
c.mocks(factory)
}
a := NewAwsInstanceBlockDeviceResourceMapper(factory)
if err := a.Execute(&[]resource.Resource{}, c.args.resourcesFromState); (err != nil) != c.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, c.wantErr)
}

View File

@ -4,10 +4,11 @@ import (
"fmt"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/sirupsen/logrus"
"github.com/cloudskiff/driftctl/pkg/alerter"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/sirupsen/logrus"
)
type invalidRouteAlert struct {
@ -29,12 +30,14 @@ func (i *invalidRouteAlert) ShouldIgnoreResource() bool {
// Explodes routes found in aws_default_route_table.route and aws_route_table.route to dedicated resources
type AwsRouteTableExpander struct {
alerter alerter.AlerterInterface
alerter alerter.AlerterInterface
resourceFactory resource.ResourceFactory
}
func NewAwsRouteTableExpander(alerter alerter.AlerterInterface) AwsRouteTableExpander {
func NewAwsRouteTableExpander(alerter alerter.AlerterInterface, resourceFactory resource.ResourceFactory) AwsRouteTableExpander {
return AwsRouteTableExpander{
alerter,
resourceFactory,
}
}
@ -80,6 +83,29 @@ func (m *AwsRouteTableExpander) handleTable(table *aws.AwsRouteTable, results *[
m.alerter.SendAlert(aws.AwsRouteTableResourceType, newInvalidRouteAlert(aws.AwsRouteTableResourceType, table.Id))
continue
}
data := map[string]interface{}{
"destination_cidr_block": route.CidrBlock,
"destination_ipv6_cidr_block": route.Ipv6CidrBlock,
"destination_prefix_list_id": "",
"egress_only_gateway_id": route.EgressOnlyGatewayId,
"gateway_id": route.GatewayId,
"id": routeId,
"instance_id": route.InstanceId,
"instance_owner_id": "",
"local_gateway_id": route.LocalGatewayId,
"nat_gateway_id": route.NatGatewayId,
"network_interface_id": route.NetworkInterfaceId,
"origin": "CreateRoute",
"route_table_id": table.Id,
"state": "active",
"transit_gateway_id": route.TransitGatewayId,
"vpc_endpoint_id": route.VpcEndpointId,
"vpc_peering_connection_id": route.VpcPeeringConnectionId,
}
ctyVal, err := m.resourceFactory.CreateResource(data, "aws_route")
if err != nil {
return err
}
newRouteFromTable := &aws.AwsRoute{
DestinationCidrBlock: route.CidrBlock,
DestinationIpv6CidrBlock: route.Ipv6CidrBlock,
@ -98,6 +124,7 @@ func (m *AwsRouteTableExpander) handleTable(table *aws.AwsRouteTable, results *[
TransitGatewayId: route.TransitGatewayId,
VpcEndpointId: route.VpcEndpointId,
VpcPeeringConnectionId: route.VpcPeeringConnectionId,
CtyVal: ctyVal,
}
normalizedRes, err := newRouteFromTable.NormalizeForState()
if err != nil {
@ -125,6 +152,28 @@ func (m *AwsRouteTableExpander) handleDefaultTable(table *aws.AwsDefaultRouteTab
m.alerter.SendAlert(aws.AwsDefaultRouteTableResourceType, newInvalidRouteAlert(aws.AwsDefaultRouteTableResourceType, table.Id))
continue
}
data := map[string]interface{}{
"destination_cidr_block": route.CidrBlock,
"destination_ipv6_cidr_block": route.Ipv6CidrBlock,
"destination_prefix_list_id": "",
"egress_only_gateway_id": route.EgressOnlyGatewayId,
"gateway_id": route.GatewayId,
"id": routeId,
"instance_id": route.InstanceId,
"instance_owner_id": "",
"nat_gateway_id": route.NatGatewayId,
"network_interface_id": route.NetworkInterfaceId,
"origin": "CreateRoute",
"route_table_id": table.Id,
"state": "active",
"transit_gateway_id": route.TransitGatewayId,
"vpc_endpoint_id": route.VpcEndpointId,
"vpc_peering_connection_id": route.VpcPeeringConnectionId,
}
ctyVal, err := m.resourceFactory.CreateResource(data, "aws_route")
if err != nil {
return err
}
newRouteFromTable := &aws.AwsRoute{
DestinationCidrBlock: route.CidrBlock,
DestinationIpv6CidrBlock: route.Ipv6CidrBlock,
@ -142,6 +191,7 @@ func (m *AwsRouteTableExpander) handleDefaultTable(table *aws.AwsDefaultRouteTab
TransitGatewayId: route.TransitGatewayId,
VpcEndpointId: route.VpcEndpointId,
VpcPeeringConnectionId: route.VpcPeeringConnectionId,
CtyVal: ctyVal,
}
normalizedRes, err := newRouteFromTable.NormalizeForState()
if err != nil {

View File

@ -6,11 +6,14 @@ import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/r3labs/diff/v2"
"github.com/stretchr/testify/mock"
"github.com/cloudskiff/driftctl/mocks"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/cloudskiff/driftctl/pkg/terraform"
resource2 "github.com/cloudskiff/driftctl/test/resource"
"github.com/r3labs/diff/v2"
)
func TestAwsRouteTableExpander_Execute(t *testing.T) {
@ -18,17 +21,17 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
name string
input []resource.Resource
expected []resource.Resource
mock func(factory *terraform.MockResourceFactory)
}{
{
"test with nil route attributes",
[]resource.Resource{
name: "test with nil route attributes",
input: []resource.Resource{
&aws.AwsRouteTable{
Id: "table_from_state",
Route: nil,
},
},
[]resource.Resource{
expected: []resource.Resource{
&aws.AwsRouteTable{
Id: "table_from_state",
Route: nil,
@ -36,8 +39,8 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
},
},
{
"test with empty route attributes",
[]resource.Resource{
name: "test with empty route attributes",
input: []resource.Resource{
&aws.AwsRouteTable{
Id: "table_from_state",
Route: &[]struct {
@ -55,7 +58,7 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
}{},
},
},
[]resource.Resource{
expected: []resource.Resource{
&aws.AwsRouteTable{
Id: "table_from_state",
Route: nil,
@ -63,8 +66,8 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
},
},
{
"test route are expanded",
[]resource.Resource{
name: "test route are expanded",
input: []resource.Resource{
&resource2.FakeResource{
Id: "fake_resource",
},
@ -95,7 +98,7 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
},
},
},
[]resource.Resource{
expected: []resource.Resource{
&resource2.FakeResource{
Id: "fake_resource",
},
@ -124,10 +127,13 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
InstanceOwnerId: awssdk.String(""),
},
},
mock: func(factory *terraform.MockResourceFactory) {
factory.On("CreateResource", mock.Anything, "aws_route").Times(2).Return(nil, nil)
},
},
{
"test route are expanded on default route tables",
[]resource.Resource{
name: "test route are expanded on default route tables",
input: []resource.Resource{
&resource2.FakeResource{
Id: "fake_resource",
},
@ -157,7 +163,7 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
},
},
},
[]resource.Resource{
expected: []resource.Resource{
&resource2.FakeResource{
Id: "fake_resource",
},
@ -186,12 +192,21 @@ func TestAwsRouteTableExpander_Execute(t *testing.T) {
InstanceOwnerId: awssdk.String(""),
},
},
mock: func(factory *terraform.MockResourceFactory) {
factory.On("CreateResource", mock.Anything, "aws_route").Times(2).Return(nil, nil)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockedAlerter := &mocks.AlerterInterface{}
m := NewAwsRouteTableExpander(mockedAlerter)
factory := &terraform.MockResourceFactory{}
if tt.mock != nil {
tt.mock(factory)
}
m := NewAwsRouteTableExpander(mockedAlerter, factory)
err := m.Execute(nil, &tt.input)
if err != nil {
t.Fatal(err)
@ -274,7 +289,9 @@ func TestAwsRouteTableExpander_ExecuteWithInvalidRoutes(t *testing.T) {
},
}
m := NewAwsRouteTableExpander(mockedAlerter)
factory := &terraform.MockResourceFactory{}
m := NewAwsRouteTableExpander(mockedAlerter, factory)
err := m.Execute(nil, &input)
if err != nil {
t.Fatal(err)

View File

@ -1,16 +1,21 @@
package middlewares
import (
"github.com/sirupsen/logrus"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/sirupsen/logrus"
)
// Explodes policy found in aws_sns_topic from state resources to aws_sns_topic_policy resources
type AwsSNSTopicPolicyExpander struct{}
type AwsSNSTopicPolicyExpander struct {
resourceFactory resource.ResourceFactory
}
func NewAwsSNSTopicPolicyExpander() AwsSNSTopicPolicyExpander {
return AwsSNSTopicPolicyExpander{}
func NewAwsSNSTopicPolicyExpander(resourceFactory resource.ResourceFactory) AwsSNSTopicPolicyExpander {
return AwsSNSTopicPolicyExpander{
resourceFactory,
}
}
func (m AwsSNSTopicPolicyExpander) Execute(_, resourcesFromState *[]resource.Resource) error {
@ -44,10 +49,21 @@ func (m *AwsSNSTopicPolicyExpander) splitPolicy(topic *aws.AwsSnsTopic, results
return nil
}
data := map[string]interface{}{
"arn": topic.Arn,
"id": topic.Id,
"policy": topic.Policy,
}
ctyVal, err := m.resourceFactory.CreateResource(data, "aws_sns_topic_policy")
if err != nil {
return err
}
newPolicy := &aws.AwsSnsTopicPolicy{
Id: topic.Id,
Arn: topic.Arn,
Policy: topic.Policy,
CtyVal: ctyVal,
}
normalized, err := newPolicy.NormalizeForState()

View File

@ -5,8 +5,10 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/mock"
awsresource "github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/r3labs/diff/v2"
@ -19,6 +21,7 @@ func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) {
name string
resourcesFromState *[]resource.Resource
expected *[]resource.Resource
mock func(factory *terraform.MockResourceFactory)
wantErr bool
}{
{
@ -42,6 +45,9 @@ func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) {
Id: "ID",
},
},
mock: func(factory *terraform.MockResourceFactory) {
factory.On("CreateResource", mock.Anything, "aws_sns_topic_policy").Once().Return(nil, nil)
},
wantErr: false,
},
{
@ -103,7 +109,13 @@ func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := AwsSNSTopicPolicyExpander{}
factory := &terraform.MockResourceFactory{}
if tt.mock != nil {
tt.mock(factory)
}
m := NewAwsSNSTopicPolicyExpander(factory)
if err := m.Execute(&[]resource.Resource{}, tt.resourcesFromState); (err != nil) != tt.wantErr {
t.Errorf("Execute() error = %v, wantErr %v", err, tt.wantErr)
}

View File

@ -2,16 +2,21 @@ package middlewares
import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/sirupsen/logrus"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/sirupsen/logrus"
)
// Explodes policy found in aws_sqs_queue.policy from state resources to dedicated resources
type AwsSqsQueuePolicyExpander struct{}
type AwsSqsQueuePolicyExpander struct {
resourceFactory resource.ResourceFactory
}
func NewAwsSqsQueuePolicyExpander() AwsSqsQueuePolicyExpander {
return AwsSqsQueuePolicyExpander{}
func NewAwsSqsQueuePolicyExpander(resourceFactory resource.ResourceFactory) AwsSqsQueuePolicyExpander {
return AwsSqsQueuePolicyExpander{
resourceFactory,
}
}
func (m AwsSqsQueuePolicyExpander) Execute(_, resourcesFromState *[]resource.Resource) error {
@ -45,10 +50,20 @@ func (m AwsSqsQueuePolicyExpander) Execute(_, resourcesFromState *[]resource.Res
}
func (m *AwsSqsQueuePolicyExpander) handlePolicy(queue *aws.AwsSqsQueue, results *[]resource.Resource) error {
data := map[string]interface{}{
"queue_url": queue.Id,
"id": queue.Id,
"policy": queue.Policy,
}
ctyVal, err := m.resourceFactory.CreateResource(data, "aws_sqs_queue_policy")
if err != nil {
return err
}
newPolicy := &aws.AwsSqsQueuePolicy{
Id: queue.Id,
QueueUrl: awssdk.String(queue.Id),
Policy: queue.Policy,
CtyVal: ctyVal,
}
normalizedRes, err := newPolicy.NormalizeForState()
if err != nil {

View File

@ -6,8 +6,12 @@ import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/mock"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/cloudskiff/driftctl/pkg/terraform"
"github.com/r3labs/diff/v2"
)
@ -90,7 +94,11 @@ func TestAwsSqsQueuePolicyExpander_Execute(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := NewAwsSqsQueuePolicyExpander()
factory := &terraform.MockResourceFactory{}
factory.On("CreateResource", mock.Anything, "aws_sqs_queue_policy").Once().Return(nil, nil)
m := NewAwsSqsQueuePolicyExpander(factory)
err := m.Execute(&[]resource.Resource{}, &tt.resourcesFromState)
if err != nil {
t.Fatal(err)

View File

@ -8,10 +8,14 @@ import (
)
// Split security group rule if it needs to given its attributes
type VPCSecurityGroupRuleSanitizer struct{}
type VPCSecurityGroupRuleSanitizer struct {
resourceFactory resource.ResourceFactory
}
func NewVPCSecurityGroupRuleSanitizer() VPCSecurityGroupRuleSanitizer {
return VPCSecurityGroupRuleSanitizer{}
func NewVPCSecurityGroupRuleSanitizer(resourceFactory resource.ResourceFactory) VPCSecurityGroupRuleSanitizer {
return VPCSecurityGroupRuleSanitizer{
resourceFactory,
}
}
func (m VPCSecurityGroupRuleSanitizer) Execute(_, resourcesFromState *[]resource.Resource) error {
@ -33,88 +37,71 @@ func (m VPCSecurityGroupRuleSanitizer) Execute(_, resourcesFromState *[]resource
if securityGroupRule.CidrBlocks != nil && len(*securityGroupRule.CidrBlocks) > 0 {
for _, ipRange := range *securityGroupRule.CidrBlocks {
rule := resourceaws.AwsSecurityGroupRule{
Type: securityGroupRule.Type,
Description: securityGroupRule.Description,
SecurityGroupId: securityGroupRule.SecurityGroupId,
Protocol: securityGroupRule.Protocol,
FromPort: securityGroupRule.FromPort,
ToPort: securityGroupRule.ToPort,
CidrBlocks: &[]string{ipRange},
Ipv6CidrBlocks: &[]string{},
PrefixListIds: &[]string{},
rule := *securityGroupRule
rule.CidrBlocks = &[]string{ipRange}
rule.Ipv6CidrBlocks = &[]string{}
rule.PrefixListIds = &[]string{}
res, err := m.createRule(&rule)
if err != nil {
return err
}
rule.Id = rule.CreateIdHash()
logrus.WithFields(logrus.Fields{
"formerRuleId": securityGroupRule.TerraformId(),
"newRuleId": rule.TerraformId(),
}).Debug("Splitting aws_security_group_rule")
newStateResources = append(newStateResources, &rule)
newStateResources = append(newStateResources, res)
}
}
if securityGroupRule.Ipv6CidrBlocks != nil && len(*securityGroupRule.Ipv6CidrBlocks) > 0 {
for _, ipRange := range *securityGroupRule.Ipv6CidrBlocks {
rule := resourceaws.AwsSecurityGroupRule{
Type: securityGroupRule.Type,
Description: securityGroupRule.Description,
SecurityGroupId: securityGroupRule.SecurityGroupId,
Protocol: securityGroupRule.Protocol,
FromPort: securityGroupRule.FromPort,
ToPort: securityGroupRule.ToPort,
CidrBlocks: &[]string{},
Ipv6CidrBlocks: &[]string{ipRange},
PrefixListIds: &[]string{},
rule := *securityGroupRule
rule.CidrBlocks = &[]string{}
rule.Ipv6CidrBlocks = &[]string{ipRange}
rule.PrefixListIds = &[]string{}
res, err := m.createRule(&rule)
if err != nil {
return err
}
rule.Id = rule.CreateIdHash()
logrus.WithFields(logrus.Fields{
"formerRuleId": securityGroupRule.TerraformId(),
"newRuleId": rule.TerraformId(),
}).Debug("Splitting aws_security_group_rule")
newStateResources = append(newStateResources, &rule)
newStateResources = append(newStateResources, res)
}
}
if securityGroupRule.PrefixListIds != nil && len(*securityGroupRule.PrefixListIds) > 0 {
for _, listId := range *securityGroupRule.PrefixListIds {
rule := resourceaws.AwsSecurityGroupRule{
Type: securityGroupRule.Type,
Description: securityGroupRule.Description,
SecurityGroupId: securityGroupRule.SecurityGroupId,
Protocol: securityGroupRule.Protocol,
FromPort: securityGroupRule.FromPort,
ToPort: securityGroupRule.ToPort,
CidrBlocks: &[]string{},
Ipv6CidrBlocks: &[]string{},
PrefixListIds: &[]string{listId},
rule := *securityGroupRule
rule.CidrBlocks = &[]string{}
rule.Ipv6CidrBlocks = &[]string{}
rule.PrefixListIds = &[]string{listId}
res, err := m.createRule(&rule)
if err != nil {
return err
}
rule.Id = rule.CreateIdHash()
logrus.WithFields(logrus.Fields{
"formerRuleId": securityGroupRule.TerraformId(),
"newRuleId": rule.TerraformId(),
}).Debug("Splitting aws_security_group_rule")
newStateResources = append(newStateResources, &rule)
newStateResources = append(newStateResources, res)
}
}
if (securityGroupRule.Self != nil && *securityGroupRule.Self) ||
(securityGroupRule.SourceSecurityGroupId != nil && *securityGroupRule.SourceSecurityGroupId != "") {
rule := resourceaws.AwsSecurityGroupRule{
Type: securityGroupRule.Type,
Description: securityGroupRule.Description,
SecurityGroupId: securityGroupRule.SecurityGroupId,
Protocol: securityGroupRule.Protocol,
FromPort: securityGroupRule.FromPort,
ToPort: securityGroupRule.ToPort,
CidrBlocks: &[]string{},
Ipv6CidrBlocks: &[]string{},
PrefixListIds: &[]string{},
Self: securityGroupRule.Self,
SourceSecurityGroupId: securityGroupRule.SourceSecurityGroupId,
rule := *securityGroupRule
rule.CidrBlocks = &[]string{}
rule.Ipv6CidrBlocks = &[]string{}
rule.PrefixListIds = &[]string{}
res, err := m.createRule(&rule)
if err != nil {
return err
}
rule.Id = rule.CreateIdHash()
logrus.WithFields(logrus.Fields{
"formerRuleId": securityGroupRule.TerraformId(),
"newRuleId": rule.TerraformId(),
}).Debug("Splitting aws_security_group_rule")
newStateResources = append(newStateResources, &rule)
newStateResources = append(newStateResources, res)
}
}
@ -123,6 +110,30 @@ func (m VPCSecurityGroupRuleSanitizer) Execute(_, resourcesFromState *[]resource
return nil
}
func (m *VPCSecurityGroupRuleSanitizer) createRule(res *resourceaws.AwsSecurityGroupRule) (*resourceaws.AwsSecurityGroupRule, error) {
res.Id = res.CreateIdHash()
data := map[string]interface{}{
"id": res.Id,
"cidr_blocks": res.CidrBlocks,
"description": res.Description,
"from_port": res.FromPort,
"ipv6_cidr_blocks": res.Ipv6CidrBlocks,
"prefix_list_ids": res.PrefixListIds,
"protocol": res.Protocol,
"security_group_id": res.SecurityGroupId,
"self": res.Self,
"source_security_group_id": res.SourceSecurityGroupId,
"to_port": res.ToPort,
"type": res.Type,
}
ctyVal, err := m.resourceFactory.CreateResource(data, "aws_security_group_rule")
if err != nil {
return nil, err
}
res.CtyVal = ctyVal
return res, err
}
func shouldBeSplit(rule *resourceaws.AwsSecurityGroupRule) bool {
var i int
if rule.CidrBlocks != nil && len(*rule.CidrBlocks) > 0 {

View File

@ -4,13 +4,19 @@ import (
"testing"
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/mock"
"github.com/cloudskiff/driftctl/pkg/resource"
"github.com/cloudskiff/driftctl/pkg/resource/aws"
"github.com/cloudskiff/driftctl/pkg/terraform"
)
func TestVPCSecurityGroupRuleSanitizer(t *testing.T) {
middleware := NewVPCSecurityGroupRuleSanitizer()
factory := &terraform.MockResourceFactory{}
factory.On("CreateResource", mock.Anything, "aws_security_group_rule").Times(8).Return(nil, nil)
middleware := NewVPCSecurityGroupRuleSanitizer(factory)
var remoteResources []resource.Resource
stateResources := []resource.Resource{
&aws.AwsSecurityGroup{

View File

@ -13,6 +13,10 @@ type Resource interface {
CtyValue() *cty.Value
}
type ResourceFactory interface {
CreateResource(data interface{}, ty string) (*cty.Value, error)
}
type SerializableResource struct {
Resource
}

View File

@ -0,0 +1,36 @@
// Code generated by mockery v2.3.0. DO NOT EDIT.
package terraform
import (
mock "github.com/stretchr/testify/mock"
cty "github.com/zclconf/go-cty/cty"
)
// MockResourceFactory is an autogenerated mock type for the ResourceFactory type
type MockResourceFactory struct {
mock.Mock
}
// CreateResource provides a mock function with given fields: data, ty
func (_m *MockResourceFactory) CreateResource(data interface{}, ty string) (*cty.Value, error) {
ret := _m.Called(data, ty)
var r0 *cty.Value
if rf, ok := ret.Get(0).(func(interface{}, string) *cty.Value); ok {
r0 = rf(data, ty)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*cty.Value)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(interface{}, string) error); ok {
r1 = rf(data, ty)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -1,6 +1,9 @@
package terraform
import (
"strings"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
@ -37,3 +40,22 @@ func (p *ProviderLibrary) Cleanup() {
provider.Cleanup()
}
}
func (p *ProviderLibrary) GetProviderForResourceType(resType string) (TerraformProvider, error) {
var name string
if strings.HasPrefix(resType, AWS) {
name = AWS
}
if strings.HasPrefix(resType, GITHUB) {
name = GITHUB
}
if name != "" {
return p.Provider(name), nil
}
return nil, errors.New("Unable to resolve provider for resource")
}

View File

@ -0,0 +1,46 @@
package terraform
import (
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/gocty"
)
type TerraformResourceFactory struct {
providerLibrary *ProviderLibrary
}
func NewTerraformResourceFactory(providerLibrary *ProviderLibrary) *TerraformResourceFactory {
return &TerraformResourceFactory{providerLibrary: providerLibrary}
}
func (r *TerraformResourceFactory) resolveType(ty string) (cty.Type, error) {
provider, err := r.providerLibrary.GetProviderForResourceType(ty)
if err != nil {
return cty.NilType, err
}
if schemas, exist := provider.Schema()[ty]; exist {
return schemas.Block.ImpliedType(), nil
}
return cty.NilType, errors.New("Unable to find ")
}
func (r *TerraformResourceFactory) CreateResource(data interface{}, ty string) (*cty.Value, error) {
ctyType, err := r.resolveType(ty)
if err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"type": ty,
}).Debug("Found cty type for resource")
val, err := gocty.ToCtyValue(data, ctyType)
if err != nil {
return nil, err
}
return &val, nil
}