driftctl/pkg/middlewares/vpc_security_group_rules.go

157 lines
6.1 KiB
Go

package middlewares
import (
"github.com/sirupsen/logrus"
"github.com/snyk/driftctl/pkg/resource"
resourceaws "github.com/snyk/driftctl/pkg/resource/aws"
)
// VPCSecurityGroupRuleSanitizer is a middleware that rewrites
// aws_security_group_rule resources read from the state: a rule whose
// attributes reference more than one target (CIDR block, IPv6 CIDR block,
// prefix list ID, or source security group) is split into one rule per target.
type VPCSecurityGroupRuleSanitizer struct {
// resourceFactory builds the replacement rule resources created by a split.
resourceFactory resource.ResourceFactory
}
// NewVPCSecurityGroupRuleSanitizer returns a VPCSecurityGroupRuleSanitizer
// that uses resourceFactory to build the rules it creates.
func NewVPCSecurityGroupRuleSanitizer(resourceFactory resource.ResourceFactory) VPCSecurityGroupRuleSanitizer {
	sanitizer := VPCSecurityGroupRuleSanitizer{resourceFactory: resourceFactory}
	return sanitizer
}
// Execute rewrites the aws_security_group_rule resources of the state so that
// each resulting rule references a single target.
//
// For every aws_security_group_rule in resourcesFromState:
//   - when "self" is true, "source_security_group_id" is set to the rule's own
//     "security_group_id" (skipped if that attribute is absent);
//   - when shouldBeSplit reports that the rule references several targets, the
//     rule is replaced by one new rule per entry of "cidr_blocks", then
//     "ipv6_cidr_blocks", then "prefix_list_ids", followed by one rule for the
//     security group reference if there is one;
//   - the "self" attribute is removed from every resulting rule.
//
// Resources of any other type are kept as they are. The slice pointed to by
// resourcesFromState is replaced with the rewritten list. The "self" attribute
// is also removed from every aws_security_group_rule in remoteResources.
//
// The returned error is always nil.
func (m VPCSecurityGroupRuleSanitizer) Execute(remoteResources, resourcesFromState *[]*resource.Resource) error {
	newStateResources := make([]*resource.Resource, 0, len(*resourcesFromState))

	for _, stateResource := range *resourcesFromState {
		// Ignore all resources other than security group rules.
		if stateResource.ResourceType() != resourceaws.AwsSecurityGroupRuleResourceType {
			newStateResources = append(newStateResources, stateResource)
			continue
		}

		// A self-referencing rule has its own group as source. GetString may
		// return nil, so only dereference it when the attribute is present.
		if self := stateResource.Attrs.GetBool("self"); self != nil && *self {
			if groupID := stateResource.Attrs.GetString("security_group_id"); groupID != nil {
				// Error deliberately ignored, as this has always been best-effort.
				_ = stateResource.Attrs.SafeSet([]string{"source_security_group_id"}, *groupID)
			}
		}

		if !shouldBeSplit(stateResource) {
			stateResource.Attrs.SafeDelete([]string{"self"})
			newStateResources = append(newStateResources, stateResource)
			continue
		}

		// Ranging over a nil or empty slice is a no-op, so no length guard is
		// needed. The emission order (IPv4, IPv6, prefix lists, group) is kept.
		for _, ipRange := range stateResource.Attrs.GetSlice("cidr_blocks") {
			newStateResources = append(newStateResources, m.splitRule(stateResource, "cidr_blocks", ipRange))
		}
		for _, ipRange := range stateResource.Attrs.GetSlice("ipv6_cidr_blocks") {
			newStateResources = append(newStateResources, m.splitRule(stateResource, "ipv6_cidr_blocks", ipRange))
		}
		for _, listID := range stateResource.Attrs.GetSlice("prefix_list_ids") {
			newStateResources = append(newStateResources, m.splitRule(stateResource, "prefix_list_ids", listID))
		}

		referencesGroup := false
		if self := stateResource.Attrs.GetBool("self"); self != nil && *self {
			referencesGroup = true
		} else if sourceID := stateResource.Attrs.GetString("source_security_group_id"); sourceID != nil && *sourceID != "" {
			referencesGroup = true
		}
		if referencesGroup {
			// An empty target key clears all three list attributes, leaving
			// only the security group reference on the new rule.
			newStateResources = append(newStateResources, m.splitRule(stateResource, "", nil))
		}
	}

	*resourcesFromState = newStateResources

	for _, res := range *remoteResources {
		if res.ResourceType() != resourceaws.AwsSecurityGroupRuleResourceType {
			continue
		}
		res.Attrs.SafeDelete([]string{"self"})
	}

	return nil
}

// splitRule builds one single-target rule from a copy of stateResource's
// attributes. The list attribute named targetKey is set to a one-element slice
// holding target; the other list attributes among "cidr_blocks",
// "ipv6_cidr_blocks" and "prefix_list_ids" are set to empty slices. Passing an
// empty targetKey empties all three. The "self" attribute is removed from the
// returned rule, and the split is logged at debug level.
func (m VPCSecurityGroupRuleSanitizer) splitRule(stateResource *resource.Resource, targetKey string, target interface{}) *resource.Resource {
	attrs := stateResource.Attrs.Copy()
	for _, key := range []string{"cidr_blocks", "ipv6_cidr_blocks", "prefix_list_ids"} {
		values := []interface{}{}
		if key == targetKey {
			values = []interface{}{target}
		}
		// Error deliberately ignored, as this has always been best-effort.
		_ = attrs.SafeSet([]string{key}, values)
	}

	res := m.createRule(attrs)
	logrus.WithFields(logrus.Fields{
		"formerRuleId": stateResource.ResourceId(),
		"newRuleId":    res.ResourceId(),
	}).Debug("Splitting aws_security_group_rule")
	res.Attrs.SafeDelete([]string{"self"})

	return res
}
// createRule builds a new aws_security_group_rule resource from res. The
// resource id is the hash computed by CreateSecurityGroupRuleIdHash over res,
// and the rule attributes listed below are copied from res as-is (a missing
// attribute is copied as nil).
func (m *VPCSecurityGroupRuleSanitizer) createRule(res *resource.Attributes) *resource.Resource {
	ruleID := resourceaws.CreateSecurityGroupRuleIdHash(res)

	source := *res
	copiedKeys := []string{
		"cidr_blocks",
		"description",
		"from_port",
		"ipv6_cidr_blocks",
		"prefix_list_ids",
		"protocol",
		"security_group_id",
		"self",
		"source_security_group_id",
		"to_port",
		"type",
	}

	data := make(map[string]interface{}, len(copiedKeys)+1)
	data["id"] = ruleID
	for _, key := range copiedKeys {
		data[key] = source[key]
	}

	return m.resourceFactory.CreateAbstractResource("aws_security_group_rule", ruleID, data)
}
// shouldBeSplit reports whether rule r references more than one target. Every
// entry of "cidr_blocks", "ipv6_cidr_blocks" and "prefix_list_ids" counts as
// one target, and a security group reference ("self" set to true, or a
// non-empty "source_security_group_id") counts as one more.
func shouldBeSplit(r *resource.Resource) bool {
	targets := 0

	// len of a nil slice is 0, so absent attributes simply add nothing.
	for _, key := range []string{"cidr_blocks", "ipv6_cidr_blocks", "prefix_list_ids"} {
		targets += len(r.Attrs.GetSlice(key))
	}

	referencesGroup := false
	if self := r.Attrs.GetBool("self"); self != nil && *self {
		referencesGroup = true
	} else if sourceID := r.Attrs.GetString("source_security_group_id"); sourceID != nil && *sourceID != "" {
		referencesGroup = true
	}
	if referencesGroup {
		targets++
	}

	return targets > 1
}