diff --git a/pkg/alerter/alert.go b/enumeration/alerter/alert.go similarity index 100% rename from pkg/alerter/alert.go rename to enumeration/alerter/alert.go diff --git a/enumeration/alerter/alerter.go b/enumeration/alerter/alerter.go new file mode 100644 index 00000000..d134d2b1 --- /dev/null +++ b/enumeration/alerter/alerter.go @@ -0,0 +1,75 @@ +package alerter + +import ( + "fmt" + + "github.com/snyk/driftctl/enumeration/resource" +) + +type AlerterInterface interface { + SendAlert(key string, alert Alert) +} + +type Alerter struct { + alerts Alerts + alertsCh chan Alerts + doneCh chan bool +} + +func NewAlerter() *Alerter { + var alerter = &Alerter{ + alerts: make(Alerts), + alertsCh: make(chan Alerts), + doneCh: make(chan bool), + } + + go alerter.run() + + return alerter +} + +func (a *Alerter) run() { + defer func() { a.doneCh <- true }() + for alert := range a.alertsCh { + for k, v := range alert { + if val, ok := a.alerts[k]; ok { + a.alerts[k] = append(val, v...) + } else { + a.alerts[k] = v + } + } + } +} + +func (a *Alerter) SetAlerts(alerts Alerts) { + a.alerts = alerts +} + +func (a *Alerter) Retrieve() Alerts { + close(a.alertsCh) + <-a.doneCh + return a.alerts +} + +func (a *Alerter) SendAlert(key string, alert Alert) { + a.alertsCh <- Alerts{ + key: []Alert{alert}, + } +} + +func (a *Alerter) IsResourceIgnored(res *resource.Resource) bool { + alert, alertExists := a.alerts[fmt.Sprintf("%s.%s", res.ResourceType(), res.ResourceId())] + wildcardAlert, wildcardAlertExists := a.alerts[res.ResourceType()] + shouldIgnoreAlert := a.shouldBeIgnored(alert) + shouldIgnoreWildcardAlert := a.shouldBeIgnored(wildcardAlert) + return (alertExists && shouldIgnoreAlert) || (wildcardAlertExists && shouldIgnoreWildcardAlert) +} + +func (a *Alerter) shouldBeIgnored(alert []Alert) bool { + for _, a := range alert { + if a.ShouldIgnoreResource() { + return true + } + } + return false +} diff --git a/enumeration/alerter/alerter_test.go b/enumeration/alerter/alerter_test.go new file mode 100644 index 00000000..ff1252d2 --- /dev/null +++ b/enumeration/alerter/alerter_test.go @@ -0,0 +1,161 @@ +package alerter + +import ( + "reflect" + "testing" + + "github.com/snyk/driftctl/enumeration/resource" +) + +func TestAlerter_Alert(t *testing.T) { + cases := []struct { + name string + alerts Alerts + expected Alerts + }{ + { + name: "TestNoAlerts", + alerts: nil, + expected: Alerts{}, + }, + { + name: "TestWithSingleAlert", + alerts: Alerts{ + "fakeres.foobar": []Alert{ + &FakeAlert{"This is an alert", false}, + }, + }, + expected: Alerts{ + "fakeres.foobar": []Alert{ + &FakeAlert{"This is an alert", false}, + }, + }, + }, + { + name: "TestWithMultipleAlerts", + alerts: Alerts{ + "fakeres.foobar": []Alert{ + &FakeAlert{"This is an alert", false}, + &FakeAlert{"This is a second alert", true}, + }, + "fakeres.barfoo": []Alert{ + &FakeAlert{"This is a third alert", true}, + }, + }, + expected: Alerts{ + "fakeres.foobar": []Alert{ + &FakeAlert{"This is an alert", false}, + &FakeAlert{"This is a second alert", true}, + }, + "fakeres.barfoo": []Alert{ + &FakeAlert{"This is a third alert", true}, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + alerter := NewAlerter() + + for k, v := range c.alerts { + for _, a := range v { + alerter.SendAlert(k, a) + } + } + + if eq := reflect.DeepEqual(alerter.Retrieve(), c.expected); !eq { + t.Errorf("Got %+v, expected %+v", alerter.Retrieve(), c.expected) + } + }) + } +} + +func TestAlerter_IgnoreResources(t *testing.T) { + cases := []struct { + name string + alerts Alerts + resource *resource.Resource + expected bool + }{ + { + name: "TestNoAlerts", + alerts: Alerts{}, + resource: &resource.Resource{ + Type: "fakeres", + Id: "foobar", + }, + expected: false, + }, + { + name: "TestShouldNotBeIgnoredWithAlerts", + alerts: Alerts{ + "fakeres": { + &FakeAlert{"Should not be ignored", false}, + }, + "fakeres.foobar": { + &FakeAlert{"Should not be ignored", false}, + }, + "fakeres.barfoo": { + &FakeAlert{"Should not be ignored", false}, + }, + "other.resource": { + &FakeAlert{"Should not be ignored", false}, + }, + }, + resource: &resource.Resource{ + Type: "fakeres", + Id: "foobar", + }, + expected: false, + }, + { + name: "TestShouldBeIgnoredWithAlertsOnWildcard", + alerts: Alerts{ + "fakeres": { + &FakeAlert{"Should be ignored", true}, + }, + "other.foobaz": { + &FakeAlert{"Should be ignored", true}, + }, + "other.resource": { + &FakeAlert{"Should not be ignored", false}, + }, + }, + resource: &resource.Resource{ + Type: "fakeres", + Id: "foobar", + }, + expected: true, + }, + { + name: "TestShouldBeIgnoredWithAlertsOnResource", + alerts: Alerts{ + "fakeres": { + &FakeAlert{"Should be ignored", true}, + }, + "other.foobaz": { + &FakeAlert{"Should be ignored", true}, + }, + "other.resource": { + &FakeAlert{"Should not be ignored", false}, + }, + }, + resource: &resource.Resource{ + Type: "other", + Id: "foobaz", + }, + expected: true, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + alerter := NewAlerter() + alerter.SetAlerts(c.alerts) + if got := alerter.IsResourceIgnored(c.resource); got != c.expected { + t.Errorf("Got %+v, expected %+v", got, c.expected) + } + }) + } +} diff --git a/enumeration/filter.go b/enumeration/filter.go new file mode 100644 index 00000000..7c8ef215 --- /dev/null +++ b/enumeration/filter.go @@ -0,0 +1,9 @@ +package enumeration + +import "github.com/snyk/driftctl/enumeration/resource" + +type Filter interface { + IsTypeIgnored(ty resource.ResourceType) bool + IsResourceIgnored(res *resource.Resource) bool + IsFieldIgnored(res *resource.Resource, path []string) bool +} diff --git a/enumeration/mock_Filter.go b/enumeration/mock_Filter.go new file mode 100644 index 00000000..172b23c8 --- /dev/null +++ b/enumeration/mock_Filter.go @@ -0,0 +1,55 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package enumeration + +import ( + resource "github.com/snyk/driftctl/enumeration/resource" + mock "github.com/stretchr/testify/mock" +) + +// MockFilter is an autogenerated mock type for the Filter type +type MockFilter struct { + mock.Mock +} + +// IsFieldIgnored provides a mock function with given fields: res, path +func (_m *MockFilter) IsFieldIgnored(res *resource.Resource, path []string) bool { + ret := _m.Called(res, path) + + var r0 bool + if rf, ok := ret.Get(0).(func(*resource.Resource, []string) bool); ok { + r0 = rf(res, path) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsResourceIgnored provides a mock function with given fields: res +func (_m *MockFilter) IsResourceIgnored(res *resource.Resource) bool { + ret := _m.Called(res) + + var r0 bool + if rf, ok := ret.Get(0).(func(*resource.Resource) bool); ok { + r0 = rf(res) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsTypeIgnored provides a mock function with given fields: ty +func (_m *MockFilter) IsTypeIgnored(ty resource.ResourceType) bool { + ret := _m.Called(ty) + + var r0 bool + if rf, ok := ret.Get(0).(func(resource.ResourceType) bool); ok { + r0 = rf(ty) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/parallel/parallel_runner.go b/enumeration/parallel/parallel_runner.go similarity index 100% rename from pkg/parallel/parallel_runner.go rename to enumeration/parallel/parallel_runner.go diff --git a/pkg/parallel/parallel_runner_test.go b/enumeration/parallel/parallel_runner_test.go similarity index 100% rename from pkg/parallel/parallel_runner_test.go rename to enumeration/parallel/parallel_runner_test.go diff --git a/enumeration/progress.go b/enumeration/progress.go new file mode 100644 index 00000000..b7dda18d --- /dev/null +++ b/enumeration/progress.go @@ -0,0 +1,5 @@ +package enumeration + +type ProgressCounter interface { + Inc() +} diff --git a/enumeration/remote/alerts/alerts.go b/enumeration/remote/alerts/alerts.go new file mode 100644 index 00000000..180369aa --- /dev/null +++ b/enumeration/remote/alerts/alerts.go @@ -0,0 +1,97 @@ +package alerts + +import ( + "fmt" + + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/sirupsen/logrus" +) + +type ScanningPhase int + +const ( + EnumerationPhase ScanningPhase = iota + DetailsFetchingPhase +) + +type RemoteAccessDeniedAlert struct { + message string + provider string + scanningPhase ScanningPhase +} + +func NewRemoteAccessDeniedAlert(provider string, scanErr *remoteerror.ResourceScanningError, scanningPhase ScanningPhase) *RemoteAccessDeniedAlert { + var message string + switch scanningPhase { + case EnumerationPhase: + message = fmt.Sprintf( + "Ignoring %s from drift calculation: Listing %s is forbidden: %s", + scanErr.Resource(), + scanErr.ListedTypeError(), + scanErr.RootCause().Error(), + ) + case DetailsFetchingPhase: + message = fmt.Sprintf( + "Ignoring %s from drift calculation: Reading details of %s is forbidden: %s", + scanErr.Resource(), + scanErr.ListedTypeError(), + scanErr.RootCause().Error(), + ) + default: + message = fmt.Sprintf( + "Ignoring %s from drift calculation: %s", + scanErr.Resource(), + scanErr.RootCause().Error(), + ) + } + return &RemoteAccessDeniedAlert{message, provider, scanningPhase} +} + +func (e *RemoteAccessDeniedAlert) Message() string { + return e.message +} + +func (e *RemoteAccessDeniedAlert) ShouldIgnoreResource() bool { + return true +} + +func (e *RemoteAccessDeniedAlert) GetProviderMessage() string { + var message string + if e.scanningPhase == DetailsFetchingPhase { + message = "It seems that we got access denied exceptions while reading details of resources.\n" + } + if e.scanningPhase == EnumerationPhase { + message = "It seems that we got access denied exceptions while listing resources.\n" + } + + switch e.provider { + case common.RemoteGithubTerraform: + message += "Please be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy" + case common.RemoteAWSTerraform: + message += "The latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy" + case common.RemoteGoogleTerraform: + message += "Please ensure that you have configured the required roles, please check our documentation at https://docs.driftctl.com/google/policy" + default: + return "" + } + return message +} + +func sendRemoteAccessDeniedAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError, p ScanningPhase) { + logrus.WithFields(logrus.Fields{ + "resource": listError.Resource(), + "listed_type": listError.ListedTypeError(), + }).Debugf("Got an access denied error: %+v", listError.Error()) + alerter.SendAlert(listError.Resource(), NewRemoteAccessDeniedAlert(provider, listError, p)) +} + +func SendEnumerationAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError) { + sendRemoteAccessDeniedAlert(provider, alerter, listError, EnumerationPhase) +} + +func SendDetailsFetchingAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError) { + sendRemoteAccessDeniedAlert(provider, alerter, listError, DetailsFetchingPhase) +} diff --git a/enumeration/remote/aws/api_gateway_account_enumerator.go b/enumeration/remote/aws/api_gateway_account_enumerator.go new file mode 100644 index 00000000..857dc23b --- /dev/null +++ b/enumeration/remote/aws/api_gateway_account_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayAccountEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayAccountEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayAccountEnumerator { + return &ApiGatewayAccountEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayAccountEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayAccountResourceType +} + +func (e *ApiGatewayAccountEnumerator) Enumerate() ([]*resource.Resource, error) { + account, err := e.repository.GetAccount() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, 1) + + if account != nil { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + "api-gateway-account", + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_api_key_enumerator.go b/enumeration/remote/aws/api_gateway_api_key_enumerator.go new file mode 100644 index 00000000..88bca526 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_api_key_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayApiKeyEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayApiKeyEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayApiKeyEnumerator { + return &ApiGatewayApiKeyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayApiKeyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayApiKeyResourceType +} + +func (e *ApiGatewayApiKeyEnumerator) Enumerate() ([]*resource.Resource, error) { + keys, err := e.repository.ListAllApiKeys() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(keys)) + + for _, key := range keys { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *key.Id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_authorizer_enumerator.go b/enumeration/remote/aws/api_gateway_authorizer_enumerator.go new file mode 100644 index 00000000..d69372cc --- /dev/null +++ b/enumeration/remote/aws/api_gateway_authorizer_enumerator.go @@ -0,0 +1,56 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayAuthorizerEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayAuthorizerEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayAuthorizerEnumerator { + return &ApiGatewayAuthorizerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayAuthorizerEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayAuthorizerResourceType +} + +func (e *ApiGatewayAuthorizerEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + authorizers, err := e.repository.ListAllRestApiAuthorizers(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, authorizer := range authorizers { + au := authorizer + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *au.Id, + map[string]interface{}{}, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_base_path_mapping_enumerator.go b/enumeration/remote/aws/api_gateway_base_path_mapping_enumerator.go new file mode 100644 index 00000000..c9a3247a --- /dev/null +++ b/enumeration/remote/aws/api_gateway_base_path_mapping_enumerator.go @@ -0,0 +1,64 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayBasePathMappingEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayBasePathMappingEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayBasePathMappingEnumerator { + return &ApiGatewayBasePathMappingEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayBasePathMappingEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayBasePathMappingResourceType +} + +func (e *ApiGatewayBasePathMappingEnumerator) Enumerate() ([]*resource.Resource, error) { + domainNames, err := e.repository.ListAllDomainNames() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayDomainNameResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, domainName := range domainNames { + d := domainName + mappings, err := e.repository.ListAllDomainNameBasePathMappings(*d.DomainName) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, mapping := range mappings { + m := mapping + + basePath := "" + if m.BasePath != nil && *m.BasePath != "(none)" { + basePath = *m.BasePath + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{*d.DomainName, basePath}, "/"), + map[string]interface{}{}, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_domain_name_enumerator.go b/enumeration/remote/aws/api_gateway_domain_name_enumerator.go new file mode 100644 index 00000000..d24fc9da --- /dev/null +++ b/enumeration/remote/aws/api_gateway_domain_name_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayDomainNameEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayDomainNameEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayDomainNameEnumerator { + return &ApiGatewayDomainNameEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayDomainNameEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayDomainNameResourceType +} + +func (e *ApiGatewayDomainNameEnumerator) Enumerate() ([]*resource.Resource, error) { + domainNames, err := e.repository.ListAllDomainNames() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(domainNames)) + + for _, domainName := range domainNames { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *domainName.DomainName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_gateway_response_enumerator.go b/enumeration/remote/aws/api_gateway_gateway_response_enumerator.go new file mode 100644 index 00000000..828ba3c5 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_gateway_response_enumerator.go @@ -0,0 +1,57 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayGatewayResponseEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayGatewayResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayGatewayResponseEnumerator { + return &ApiGatewayGatewayResponseEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayGatewayResponseEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayGatewayResponseResourceType +} + +func (e *ApiGatewayGatewayResponseEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + gtwResponses, err := e.repository.ListAllRestApiGatewayResponses(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, gtwResponse := range gtwResponses { + g := gtwResponse + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{"aggr", *a.Id, *g.ResponseType}, "-"), + map[string]interface{}{}, + ), + ) + } + + } + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_integration_enumerator.go b/enumeration/remote/aws/api_gateway_integration_enumerator.go new file mode 100644 index 00000000..ef2089d3 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_integration_enumerator.go @@ -0,0 +1,59 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayIntegrationEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayIntegrationEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayIntegrationEnumerator { + return &ApiGatewayIntegrationEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayIntegrationEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayIntegrationResourceType +} + +func (e *ApiGatewayIntegrationEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + resources, err := e.repository.ListAllRestApiResources(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) + } + + for _, resource := range resources { + r := resource + for httpMethod := range r.ResourceMethods { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{"agi", *a.Id, *r.Id, httpMethod}, "-"), + map[string]interface{}{}, + ), + ) + } + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_integration_response_enumerator.go b/enumeration/remote/aws/api_gateway_integration_response_enumerator.go new file mode 100644 index 00000000..3ede0317 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_integration_response_enumerator.go @@ -0,0 +1,63 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayIntegrationResponseEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayIntegrationResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayIntegrationResponseEnumerator { + return &ApiGatewayIntegrationResponseEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayIntegrationResponseEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayIntegrationResponseResourceType +} + +func (e *ApiGatewayIntegrationResponseEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + resources, err := e.repository.ListAllRestApiResources(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) + } + + for _, resource := range resources { + r := resource + for httpMethod, method := range r.ResourceMethods { + if method.MethodIntegration != nil { + for statusCode := range method.MethodIntegration.IntegrationResponses { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{"agir", *a.Id, *r.Id, httpMethod, statusCode}, "-"), + map[string]interface{}{}, + ), + ) + } + } + } + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_method_enumerator.go b/enumeration/remote/aws/api_gateway_method_enumerator.go new file mode 100644 index 00000000..15034e27 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_method_enumerator.go @@ -0,0 +1,59 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayMethodEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayMethodEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodEnumerator { + return &ApiGatewayMethodEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayMethodEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayMethodResourceType +} + +func (e *ApiGatewayMethodEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + resources, err := e.repository.ListAllRestApiResources(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) + } + + for _, resource := range resources { + r := resource + for method := range r.ResourceMethods { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{"agm", *a.Id, *r.Id, method}, "-"), + map[string]interface{}{}, + ), + ) + } + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_method_response_enumerator.go b/enumeration/remote/aws/api_gateway_method_response_enumerator.go new file mode 100644 index 00000000..c0cee595 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_method_response_enumerator.go @@ -0,0 +1,61 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayMethodResponseEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayMethodResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodResponseEnumerator { + return &ApiGatewayMethodResponseEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayMethodResponseEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayMethodResponseResourceType +} + +func (e *ApiGatewayMethodResponseEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + resources, err := e.repository.ListAllRestApiResources(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) + } + + for _, resource := range resources { + r := resource + for httpMethod, method := range r.ResourceMethods { + for statusCode := range method.MethodResponses { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{"agmr", *a.Id, *r.Id, httpMethod, statusCode}, "-"), + map[string]interface{}{}, + ), + ) + } + } + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_method_settings_enumerator.go b/enumeration/remote/aws/api_gateway_method_settings_enumerator.go new file mode 100644 index 00000000..254359f2 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_method_settings_enumerator.go @@ -0,0 +1,59 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayMethodSettingsEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayMethodSettingsEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodSettingsEnumerator { + return &ApiGatewayMethodSettingsEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayMethodSettingsEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayMethodSettingsResourceType +} + +func (e *ApiGatewayMethodSettingsEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + stages, err := e.repository.ListAllRestApiStages(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayStageResourceType) + } + + for _, stage := range stages { + s := stage + for methodPath := range s.MethodSettings { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{*a.Id, *s.StageName, methodPath}, "-"), + map[string]interface{}{}, + ), + ) + } + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_model_enumerator.go b/enumeration/remote/aws/api_gateway_model_enumerator.go new file mode 100644 index 00000000..e2a3b53e --- /dev/null +++ b/enumeration/remote/aws/api_gateway_model_enumerator.go @@ -0,0 +1,55 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayModelEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayModelEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayModelEnumerator { + return &ApiGatewayModelEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayModelEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayModelResourceType +} + +func (e *ApiGatewayModelEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + models, err := e.repository.ListAllRestApiModels(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, model := range models { + m := model + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *m.Id, + map[string]interface{}{}, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_request_validator_enumerator.go b/enumeration/remote/aws/api_gateway_request_validator_enumerator.go new file mode 100644 index 00000000..d886bfc9 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_request_validator_enumerator.go @@ -0,0 +1,55 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayRequestValidatorEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayRequestValidatorEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRequestValidatorEnumerator { + return &ApiGatewayRequestValidatorEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayRequestValidatorEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayRequestValidatorResourceType +} + +func (e *ApiGatewayRequestValidatorEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + requestValidators, err := e.repository.ListAllRestApiRequestValidators(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, requestValidator := range requestValidators { + r := requestValidator + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *r.Id, + map[string]interface{}{}, + ), + ) + } + + } + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_resource_enumerator.go b/enumeration/remote/aws/api_gateway_resource_enumerator.go new file mode 100644 index 00000000..ea1bd800 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_resource_enumerator.go @@ -0,0 +1,58 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayResourceEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayResourceEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayResourceEnumerator { + return &ApiGatewayResourceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayResourceEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayResourceResourceType +} + +func (e *ApiGatewayResourceEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + resources, err := e.repository.ListAllRestApiResources(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, resource := range resources { + r := resource + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *r.Id, + map[string]interface{}{ + "rest_api_id": *a.Id, + "path": *r.Path, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_rest_api_enumerator.go b/enumeration/remote/aws/api_gateway_rest_api_enumerator.go new file mode 100644 index 00000000..44588d25 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_rest_api_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayRestApiEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayRestApiEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRestApiEnumerator { + return &ApiGatewayRestApiEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayRestApiEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayRestApiResourceType +} + +func (e *ApiGatewayRestApiEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(apis)) + + for _, api := range apis { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *api.Id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_rest_api_policy_enumerator.go b/enumeration/remote/aws/api_gateway_rest_api_policy_enumerator.go new file mode 100644 index 00000000..9cb109bb --- /dev/null +++ b/enumeration/remote/aws/api_gateway_rest_api_policy_enumerator.go @@ -0,0 +1,49 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayRestApiPolicyEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayRestApiPolicyEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRestApiPolicyEnumerator { + return &ApiGatewayRestApiPolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayRestApiPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayRestApiPolicyResourceType +} + +func (e *ApiGatewayRestApiPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + if a.Policy == nil || *a.Policy == "" { + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *a.Id, + map[string]interface{}{}, + ), + ) + } + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_stage_enumerator.go b/enumeration/remote/aws/api_gateway_stage_enumerator.go new file mode 100644 index 00000000..34860c9c --- /dev/null +++ b/enumeration/remote/aws/api_gateway_stage_enumerator.go @@ -0,0 +1,57 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayStageEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayStageEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayStageEnumerator { + return &ApiGatewayStageEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayStageEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayStageResourceType +} + +func (e *ApiGatewayStageEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllRestApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + stages, err := e.repository.ListAllRestApiStages(*a.Id) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, stage := range stages { + s := stage + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{"ags", *a.Id, *s.StageName}, "-"), + map[string]interface{}{}, + ), + ) + } + + } + return results, err +} diff --git a/enumeration/remote/aws/api_gateway_vpc_link_enumerator.go b/enumeration/remote/aws/api_gateway_vpc_link_enumerator.go new file mode 100644 index 00000000..a66ace88 --- /dev/null +++ b/enumeration/remote/aws/api_gateway_vpc_link_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayVpcLinkEnumerator struct { + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayVpcLinkEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayVpcLinkEnumerator { + return &ApiGatewayVpcLinkEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayVpcLinkEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayVpcLinkResourceType +} + +func (e *ApiGatewayVpcLinkEnumerator) Enumerate() ([]*resource.Resource, error) { + vpcLinks, err := e.repository.ListAllVpcLinks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(vpcLinks)) + + for _, vpcLink := range vpcLinks { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *vpcLink.Id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_api_enumerator.go b/enumeration/remote/aws/apigatewayv2_api_enumerator.go new file mode 100644 index 00000000..34333542 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_api_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2ApiEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2ApiEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2ApiEnumerator { + return &ApiGatewayV2ApiEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2ApiEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2ApiResourceType +} + +func (e *ApiGatewayV2ApiEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(apis)) + + for _, api := range apis { + a := api + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *a.ApiId, + map[string]interface{}{}, + ), + ) + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_authorizer_enumerator.go b/enumeration/remote/aws/apigatewayv2_authorizer_enumerator.go new file mode 100644 index 00000000..7370abf8 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_authorizer_enumerator.go @@ -0,0 +1,56 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2AuthorizerEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2AuthorizerEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2AuthorizerEnumerator { + return &ApiGatewayV2AuthorizerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2AuthorizerEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2AuthorizerResourceType +} + +func (e *ApiGatewayV2AuthorizerEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + a := api + authorizers, err := e.repository.ListAllApiAuthorizers(*a.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, authorizer := range authorizers { + au := authorizer + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *au.AuthorizerId, + map[string]interface{}{}, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_deployment_enumerator.go b/enumeration/remote/aws/apigatewayv2_deployment_enumerator.go new file mode 100644 index 00000000..43eef1fe --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_deployment_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2DeploymentEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2DeploymentEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2DeploymentEnumerator { + return &ApiGatewayV2DeploymentEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2DeploymentEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2DeploymentResourceType +} + +func (e *ApiGatewayV2DeploymentEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + var results []*resource.Resource + for _, api := range apis { + deployments, err := e.repository.ListAllApiDeployments(api.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, deployment := range deployments { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *deployment.DeploymentId, + map[string]interface{}{}, + ), + ) + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_domain_name_enumerator.go b/enumeration/remote/aws/apigatewayv2_domain_name_enumerator.go new file mode 100644 index 00000000..87315f4f --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_domain_name_enumerator.go @@ -0,0 +1,49 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2DomainNameEnumerator struct { + // AWS SDK list domain names endpoint from API Gateway v2 returns the + // same results as the v1 one, thus let's re-use the method from + // the API Gateway v1 + repository repository.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayV2DomainNameEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayV2DomainNameEnumerator { + return &ApiGatewayV2DomainNameEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2DomainNameEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2DomainNameResourceType +} + +func (e *ApiGatewayV2DomainNameEnumerator) Enumerate() ([]*resource.Resource, error) { + domainNames, err := e.repository.ListAllDomainNames() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(domainNames)) + + for _, domainName := range domainNames { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *domainName.DomainName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_integration_enumerator.go b/enumeration/remote/aws/apigatewayv2_integration_enumerator.go new file mode 100644 index 00000000..57552d87 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_integration_enumerator.go @@ -0,0 +1,63 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2IntegrationEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2IntegrationEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2IntegrationEnumerator { + return &ApiGatewayV2IntegrationEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2IntegrationEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2IntegrationResourceType +} + +func (e *ApiGatewayV2IntegrationEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, a := range apis { + api := a + integrations, err := e.repository.ListAllApiIntegrations(*api.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, integration := range integrations { + data := map[string]interface{}{ + "api_id": *api.ApiId, + "integration_type": *integration.IntegrationType, + } + + if integration.IntegrationMethod != nil { + // this is needed to discriminate in middleware. But it is nil when the type is mock... + data["integration_method"] = *integration.IntegrationMethod + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *integration.IntegrationId, + data, + ), + ) + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_integration_response_enumerator.go b/enumeration/remote/aws/apigatewayv2_integration_response_enumerator.go new file mode 100644 index 00000000..62f45850 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_integration_response_enumerator.go @@ -0,0 +1,63 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2IntegrationResponseEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2IntegrationResponseEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2IntegrationResponseEnumerator { + return &ApiGatewayV2IntegrationResponseEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2IntegrationResponseEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2IntegrationResponseResourceType +} + +func (e *ApiGatewayV2IntegrationResponseEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, a := range apis { + apiID := *a.ApiId + integrations, err := e.repository.ListAllApiIntegrations(apiID) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2IntegrationResourceType) + } + + for _, integration := range integrations { + integrationId := *integration.IntegrationId + responses, err := e.repository.ListAllApiIntegrationResponses(apiID, integrationId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, resp := range responses { + responseId := *resp.IntegrationResponseId + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + responseId, + map[string]interface{}{}, + ), + ) + } + + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_mapping_enumerator.go b/enumeration/remote/aws/apigatewayv2_mapping_enumerator.go new file mode 100644 index 00000000..01e812a0 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_mapping_enumerator.go @@ -0,0 +1,61 @@ +package aws + +import ( + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2MappingEnumerator struct { + repository repository2.ApiGatewayV2Repository + repositoryV1 repository2.ApiGatewayRepository + factory resource.ResourceFactory +} + +func NewApiGatewayV2MappingEnumerator(repo repository2.ApiGatewayV2Repository, repov1 repository2.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayV2MappingEnumerator { + return &ApiGatewayV2MappingEnumerator{ + repository: repo, + repositoryV1: repov1, + factory: factory, + } +} + +func (e *ApiGatewayV2MappingEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2MappingResourceType +} + +func (e *ApiGatewayV2MappingEnumerator) Enumerate() ([]*resource.Resource, error) { + domainNames, err := e.repositoryV1.ListAllDomainNames() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayDomainNameResourceType) + } + + var results []*resource.Resource + for _, domainName := range domainNames { + mappings, err := e.repository.ListAllApiMappings(*domainName.DomainName) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, mapping := range mappings { + attrs := make(map[string]interface{}) + + if mapping.ApiId != nil { + attrs["api_id"] = *mapping.ApiId + } + if mapping.Stage != nil { + attrs["stage"] = *mapping.Stage + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *mapping.ApiMappingId, + attrs, + ), + ) + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_model_enumerator.go b/enumeration/remote/aws/apigatewayv2_model_enumerator.go new file mode 100644 index 00000000..0704f290 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_model_enumerator.go @@ -0,0 +1,52 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2ModelEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2ModelEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2ModelEnumerator { + return &ApiGatewayV2ModelEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2ModelEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2ModelResourceType +} + +func (e *ApiGatewayV2ModelEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + var results []*resource.Resource + for _, api := range apis { + models, err := e.repository.ListAllApiModels(*api.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, model := range models { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *model.ModelId, + map[string]interface{}{ + "name": *model.Name, + }, + ), + ) + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_route_enumerator.go b/enumeration/remote/aws/apigatewayv2_route_enumerator.go new file mode 100644 index 00000000..b0cfc02f --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_route_enumerator.go @@ -0,0 +1,53 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2RouteEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2RouteEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2RouteEnumerator { + return &ApiGatewayV2RouteEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2RouteEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2RouteResourceType +} + +func (e *ApiGatewayV2RouteEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + var results []*resource.Resource + for _, api := range apis { + routes, err := e.repository.ListAllApiRoutes(api.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, route := range routes { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *route.RouteId, + map[string]interface{}{ + "api_id": *api.ApiId, + "route_key": *route.RouteKey, + }, + ), + ) + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_route_response_enumerator.go b/enumeration/remote/aws/apigatewayv2_route_response_enumerator.go new file mode 100644 index 00000000..9c6d5a62 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_route_response_enumerator.go @@ -0,0 +1,59 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2RouteResponseEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2RouteResponseEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2RouteResponseEnumerator { + return &ApiGatewayV2RouteResponseEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2RouteResponseEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2RouteResponseResourceType +} + +func (e *ApiGatewayV2RouteResponseEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + var results []*resource.Resource + for _, api := range apis { + a := api + routes, err := e.repository.ListAllApiRoutes(a.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2RouteResourceType) + } + for _, route := range routes { + r := route + responses, err := e.repository.ListAllApiRouteResponses(*a.ApiId, *r.RouteId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, response := range responses { + res := response + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.RouteResponseId, + map[string]interface{}{}, + ), + ) + } + } + } + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_stage_enumerator.go b/enumeration/remote/aws/apigatewayv2_stage_enumerator.go new file mode 100644 index 00000000..bdc14473 --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_stage_enumerator.go @@ -0,0 +1,54 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2StageEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2StageEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2StageEnumerator { + return &ApiGatewayV2StageEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2StageEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2StageResourceType +} + +func (e *ApiGatewayV2StageEnumerator) Enumerate() ([]*resource.Resource, error) { + apis, err := e.repository.ListAllApis() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, api := range apis { + stages, err := e.repository.ListAllApiStages(*api.ApiId) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, stage := range stages { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *stage.StageName, + map[string]interface{}{}, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/aws/apigatewayv2_vpc_link_enumerator.go b/enumeration/remote/aws/apigatewayv2_vpc_link_enumerator.go new file mode 100644 index 00000000..5edef1aa --- /dev/null +++ b/enumeration/remote/aws/apigatewayv2_vpc_link_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ApiGatewayV2VpcLinkEnumerator struct { + repository repository.ApiGatewayV2Repository + factory resource.ResourceFactory +} + +func NewApiGatewayV2VpcLinkEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2VpcLinkEnumerator { + return &ApiGatewayV2VpcLinkEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ApiGatewayV2VpcLinkEnumerator) SupportedType() resource.ResourceType { + return aws.AwsApiGatewayV2VpcLinkResourceType +} + +func (e *ApiGatewayV2VpcLinkEnumerator) Enumerate() ([]*resource.Resource, error) { + vpcLinks, err := e.repository.ListAllVpcLinks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(vpcLinks)) + + for _, vpcLink := range vpcLinks { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *vpcLink.VpcLinkId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/appautoscaling_policy_enumerator.go b/enumeration/remote/aws/appautoscaling_policy_enumerator.go new file mode 100644 index 00000000..4662b8b9 --- /dev/null +++ b/enumeration/remote/aws/appautoscaling_policy_enumerator.go @@ -0,0 +1,53 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type AppAutoscalingPolicyEnumerator struct { + repository repository.AppAutoScalingRepository + factory resource.ResourceFactory +} + +func NewAppAutoscalingPolicyEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingPolicyEnumerator { + return &AppAutoscalingPolicyEnumerator{ + repository, + factory, + } +} + +func (e *AppAutoscalingPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsAppAutoscalingPolicyResourceType +} + +func (e *AppAutoscalingPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + results := make([]*resource.Resource, 0) + + for _, ns := range e.repository.ServiceNamespaceValues() { + policies, err := e.repository.DescribeScalingPolicies(ns) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, policy := range policies { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *policy.PolicyName, + map[string]interface{}{ + "name": *policy.PolicyName, + "resource_id": *policy.ResourceId, + "scalable_dimension": *policy.ScalableDimension, + "service_namespace": *policy.ServiceNamespace, + }, + ), + ) + } + } + + return results, nil +} diff --git a/enumeration/remote/aws/appautoscaling_scheduled_action_enumerator.go b/enumeration/remote/aws/appautoscaling_scheduled_action_enumerator.go new file mode 100644 index 00000000..3cb182e7 --- /dev/null +++ b/enumeration/remote/aws/appautoscaling_scheduled_action_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type AppAutoscalingScheduledActionEnumerator struct { + repository repository.AppAutoScalingRepository + factory resource.ResourceFactory +} + +func NewAppAutoscalingScheduledActionEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingScheduledActionEnumerator { + return &AppAutoscalingScheduledActionEnumerator{ + repository, + factory, + } +} + +func (e *AppAutoscalingScheduledActionEnumerator) SupportedType() resource.ResourceType { + return aws.AwsAppAutoscalingScheduledActionResourceType +} + +func (e *AppAutoscalingScheduledActionEnumerator) Enumerate() ([]*resource.Resource, error) { + results := make([]*resource.Resource, 0) + + for _, ns := range e.repository.ServiceNamespaceValues() { + actions, err := e.repository.DescribeScheduledActions(ns) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, action := range actions { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.Join([]string{*action.ScheduledActionName, *action.ServiceNamespace, *action.ResourceId}, "-"), + map[string]interface{}{}, + ), + ) + } + } + + return results, nil +} diff --git a/enumeration/remote/aws/appautoscaling_target_enumerator.go b/enumeration/remote/aws/appautoscaling_target_enumerator.go new file mode 100644 index 00000000..fbe039aa --- /dev/null +++ b/enumeration/remote/aws/appautoscaling_target_enumerator.go @@ -0,0 +1,55 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go/service/applicationautoscaling" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type AppAutoscalingTargetEnumerator struct { + repository repository.AppAutoScalingRepository + factory resource.ResourceFactory +} + +func NewAppAutoscalingTargetEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingTargetEnumerator { + return &AppAutoscalingTargetEnumerator{ + repository, + factory, + } +} + +func (e *AppAutoscalingTargetEnumerator) SupportedType() resource.ResourceType { + return aws.AwsAppAutoscalingTargetResourceType +} + +func (e *AppAutoscalingTargetEnumerator) Enumerate() ([]*resource.Resource, error) { + targets := make([]*applicationautoscaling.ScalableTarget, 0) + + for _, ns := range e.repository.ServiceNamespaceValues() { + results, err := e.repository.DescribeScalableTargets(ns) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + targets = append(targets, results...) + } + + results := make([]*resource.Resource, 0, len(targets)) + + for _, target := range targets { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *target.ResourceId, + map[string]interface{}{ + "service_namespace": *target.ServiceNamespace, + "scalable_dimension": *target.ScalableDimension, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/classic_loadbalancer_enumerator.go b/enumeration/remote/aws/classic_loadbalancer_enumerator.go new file mode 100644 index 00000000..7890f159 --- /dev/null +++ b/enumeration/remote/aws/classic_loadbalancer_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ClassicLoadBalancerEnumerator struct { + repository repository.ELBRepository + factory resource.ResourceFactory +} + +func NewClassicLoadBalancerEnumerator(repo repository.ELBRepository, factory resource.ResourceFactory) *ClassicLoadBalancerEnumerator { + return &ClassicLoadBalancerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ClassicLoadBalancerEnumerator) SupportedType() resource.ResourceType { + return aws.AwsClassicLoadBalancerResourceType +} + +func (e *ClassicLoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) { + loadBalancers, err := e.repository.ListAllLoadBalancers() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(loadBalancers)) + + for _, lb := range loadBalancers { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *lb.LoadBalancerName, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/pkg/remote/aws/client/mock_AwsClientFactoryInterface.go b/enumeration/remote/aws/client/mock_AwsClientFactoryInterface.go similarity index 100% rename from pkg/remote/aws/client/mock_AwsClientFactoryInterface.go rename to enumeration/remote/aws/client/mock_AwsClientFactoryInterface.go diff --git a/pkg/remote/aws/client/s3_client_factory.go b/enumeration/remote/aws/client/s3_client_factory.go similarity index 100% rename from pkg/remote/aws/client/s3_client_factory.go rename to enumeration/remote/aws/client/s3_client_factory.go diff --git a/enumeration/remote/aws/cloudformation_stack_enumerator.go b/enumeration/remote/aws/cloudformation_stack_enumerator.go new file mode 100644 index 00000000..5be5e519 --- /dev/null +++ b/enumeration/remote/aws/cloudformation_stack_enumerator.go @@ -0,0 +1,60 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go/service/cloudformation" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type CloudformationStackEnumerator struct { + repository repository.CloudformationRepository + factory resource.ResourceFactory +} + +func NewCloudformationStackEnumerator(repo repository.CloudformationRepository, factory resource.ResourceFactory) *CloudformationStackEnumerator { + return &CloudformationStackEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *CloudformationStackEnumerator) SupportedType() resource.ResourceType { + return aws.AwsCloudformationStackResourceType +} + +func (e *CloudformationStackEnumerator) Enumerate() ([]*resource.Resource, error) { + stacks, err := e.repository.ListAllStacks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(stacks)) + + for _, stack := range stacks { + attrs := map[string]interface{}{} + if stack.Parameters != nil && len(stack.Parameters) > 0 { + attrs["parameters"] = flattenParameters(stack.Parameters) + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *stack.StackId, + attrs, + ), + ) + } + + return results, err +} + +func flattenParameters(parameters []*cloudformation.Parameter) interface{} { + params := make(map[string]interface{}, len(parameters)) + for _, p := range parameters { + params[*p.ParameterKey] = *p.ParameterValue + } + return params +} diff --git a/enumeration/remote/aws/cloudfront_distribution_enumerator.go b/enumeration/remote/aws/cloudfront_distribution_enumerator.go new file mode 100644 index 00000000..58aa7c1b --- /dev/null +++ b/enumeration/remote/aws/cloudfront_distribution_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type CloudfrontDistributionEnumerator struct { + repository repository.CloudfrontRepository + factory resource.ResourceFactory +} + +func NewCloudfrontDistributionEnumerator(repo repository.CloudfrontRepository, factory resource.ResourceFactory) *CloudfrontDistributionEnumerator { + return &CloudfrontDistributionEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *CloudfrontDistributionEnumerator) SupportedType() resource.ResourceType { + return aws.AwsCloudfrontDistributionResourceType +} + +func (e *CloudfrontDistributionEnumerator) Enumerate() ([]*resource.Resource, error) { + distributions, err := e.repository.ListAllDistributions() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(distributions)) + + for _, distribution := range distributions { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *distribution.Id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/default_vpc_enumerator.go b/enumeration/remote/aws/default_vpc_enumerator.go new file mode 100644 index 00000000..b49b8fb0 --- /dev/null +++ b/enumeration/remote/aws/default_vpc_enumerator.go @@ -0,0 +1,47 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/snyk/driftctl/enumeration/resource" +) + +type DefaultVPCEnumerator struct { + repo repository.EC2Repository + factory resource.ResourceFactory +} + +func NewDefaultVPCEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *DefaultVPCEnumerator { + return &DefaultVPCEnumerator{ + repo, + factory, + } +} + +func (e *DefaultVPCEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDefaultVpcResourceType +} + +func (e *DefaultVPCEnumerator) Enumerate() ([]*resource.Resource, error) { + _, defaultVPCs, err := e.repo.ListAllVPCs() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(defaultVPCs)) + + for _, item := range defaultVPCs { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *item.VpcId, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/dynamodb_table_enumerator.go b/enumeration/remote/aws/dynamodb_table_enumerator.go new file mode 100644 index 00000000..37cdaffe --- /dev/null +++ b/enumeration/remote/aws/dynamodb_table_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type DynamoDBTableEnumerator struct { + repository repository.DynamoDBRepository + factory resource.ResourceFactory +} + +func NewDynamoDBTableEnumerator(repository repository.DynamoDBRepository, factory resource.ResourceFactory) *DynamoDBTableEnumerator { + return &DynamoDBTableEnumerator{ + repository, + factory, + } +} + +func (e *DynamoDBTableEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDynamodbTableResourceType +} + +func (e *DynamoDBTableEnumerator) Enumerate() ([]*resource.Resource, error) { + tables, err := e.repository.ListAllTables() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(tables)) + + for _, table := range tables { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *table, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/ebs_encryption_by_default_enumerator.go b/enumeration/remote/aws/ebs_encryption_by_default_enumerator.go new file mode 100644 index 00000000..2363ca08 --- /dev/null +++ b/enumeration/remote/aws/ebs_encryption_by_default_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2EbsEncryptionByDefaultEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2EbsEncryptionByDefaultEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsEncryptionByDefaultEnumerator { + return &EC2EbsEncryptionByDefaultEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2EbsEncryptionByDefaultEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEbsEncryptionByDefaultResourceType +} + +func (e *EC2EbsEncryptionByDefaultEnumerator) Enumerate() ([]*resource.Resource, error) { + enabled, err := e.repository.IsEbsEncryptionEnabledByDefault() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + "ebs_encryption_default", + map[string]interface{}{ + "enabled": enabled, + }, + ), + ) + + return results, err +} diff --git a/enumeration/remote/aws/ec2_ami_enumerator.go b/enumeration/remote/aws/ec2_ami_enumerator.go new file mode 100644 index 00000000..ed409fb3 --- /dev/null +++ b/enumeration/remote/aws/ec2_ami_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2AmiEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2AmiEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2AmiEnumerator { + return &EC2AmiEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2AmiEnumerator) SupportedType() resource.ResourceType { + return aws.AwsAmiResourceType +} + +func (e *EC2AmiEnumerator) Enumerate() ([]*resource.Resource, error) { + images, err := e.repository.ListAllImages() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(images)) + + for _, image := range images { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *image.ImageId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_default_network_acl_enumerator.go b/enumeration/remote/aws/ec2_default_network_acl_enumerator.go new file mode 100644 index 00000000..1490211e --- /dev/null +++ b/enumeration/remote/aws/ec2_default_network_acl_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2DefaultNetworkACLEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2DefaultNetworkACLEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultNetworkACLEnumerator { + return &EC2DefaultNetworkACLEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2DefaultNetworkACLEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDefaultNetworkACLResourceType +} + +func (e *EC2DefaultNetworkACLEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllNetworkACLs() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + // Do not handle non-default network acl since it is a dedicated resource + if !*res.IsDefault { + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.NetworkAclId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_default_route_table_enumerator.go b/enumeration/remote/aws/ec2_default_route_table_enumerator.go new file mode 100644 index 00000000..75ec8036 --- /dev/null +++ b/enumeration/remote/aws/ec2_default_route_table_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2DefaultRouteTableEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2DefaultRouteTableEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultRouteTableEnumerator { + return &EC2DefaultRouteTableEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2DefaultRouteTableEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDefaultRouteTableResourceType +} + +func (e *EC2DefaultRouteTableEnumerator) Enumerate() ([]*resource.Resource, error) { + routeTables, err := e.repository.ListAllRouteTables() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + var results []*resource.Resource + + for _, routeTable := range routeTables { + if isMainRouteTable(routeTable) { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *routeTable.RouteTableId, + map[string]interface{}{ + "vpc_id": *routeTable.VpcId, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_default_subnet_enumerator.go b/enumeration/remote/aws/ec2_default_subnet_enumerator.go new file mode 100644 index 00000000..dbc5cef9 --- /dev/null +++ b/enumeration/remote/aws/ec2_default_subnet_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2DefaultSubnetEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2DefaultSubnetEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultSubnetEnumerator { + return &EC2DefaultSubnetEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2DefaultSubnetEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDefaultSubnetResourceType +} + +func (e *EC2DefaultSubnetEnumerator) Enumerate() ([]*resource.Resource, error) { + _, defaultSubnets, err := e.repository.ListAllSubnets() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(defaultSubnets)) + + for _, subnet := range defaultSubnets { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *subnet.SubnetId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_ebs_snapshot_enumerator.go b/enumeration/remote/aws/ec2_ebs_snapshot_enumerator.go new file mode 100644 index 00000000..20d69dd0 --- /dev/null +++ b/enumeration/remote/aws/ec2_ebs_snapshot_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2EbsSnapshotEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2EbsSnapshotEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsSnapshotEnumerator { + return &EC2EbsSnapshotEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2EbsSnapshotEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEbsSnapshotResourceType +} + +func (e *EC2EbsSnapshotEnumerator) Enumerate() ([]*resource.Resource, error) { + snapshots, err := e.repository.ListAllSnapshots() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(snapshots)) + + for _, snapshot := range snapshots { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *snapshot.SnapshotId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_ebs_volume_enumerator.go b/enumeration/remote/aws/ec2_ebs_volume_enumerator.go new file mode 100644 index 00000000..7528b139 --- /dev/null +++ b/enumeration/remote/aws/ec2_ebs_volume_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2EbsVolumeEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2EbsVolumeEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsVolumeEnumerator { + return &EC2EbsVolumeEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2EbsVolumeEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEbsVolumeResourceType +} + +func (e *EC2EbsVolumeEnumerator) Enumerate() ([]*resource.Resource, error) { + volumes, err := e.repository.ListAllVolumes() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(volumes)) + + for _, volume := range volumes { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *volume.VolumeId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_eip_association_enumerator.go b/enumeration/remote/aws/ec2_eip_association_enumerator.go new file mode 100644 index 00000000..7da15c1c --- /dev/null +++ b/enumeration/remote/aws/ec2_eip_association_enumerator.go @@ -0,0 +1,48 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2EipAssociationEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2EipAssociationEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EipAssociationEnumerator { + return &EC2EipAssociationEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2EipAssociationEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEipAssociationResourceType +} + +func (e *EC2EipAssociationEnumerator) Enumerate() ([]*resource.Resource, error) { + addresses, err := e.repository.ListAllAddressesAssociation() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(addresses)) + + for _, address := range addresses { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *address.AssociationId, + map[string]interface{}{ + "allocation_id": *address.AllocationId, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_eip_enumerator.go b/enumeration/remote/aws/ec2_eip_enumerator.go new file mode 100644 index 00000000..a4ace574 --- /dev/null +++ b/enumeration/remote/aws/ec2_eip_enumerator.go @@ -0,0 +1,51 @@ +package aws + +import ( + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2EipEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2EipEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EipEnumerator { + return &EC2EipEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2EipEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEipResourceType +} + +func (e *EC2EipEnumerator) Enumerate() ([]*resource.Resource, error) { + addresses, err := e.repository.ListAllAddresses() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(addresses)) + + for _, address := range addresses { + if address.AllocationId == nil { + logrus.Warn("Elastic IP does not have an allocation ID, ignoring") + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *address.AllocationId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_instance_enumerator.go b/enumeration/remote/aws/ec2_instance_enumerator.go new file mode 100644 index 00000000..0779759f --- /dev/null +++ b/enumeration/remote/aws/ec2_instance_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2InstanceEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2InstanceEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2InstanceEnumerator { + return &EC2InstanceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2InstanceEnumerator) SupportedType() resource.ResourceType { + return aws.AwsInstanceResourceType +} + +func (e *EC2InstanceEnumerator) Enumerate() ([]*resource.Resource, error) { + instances, err := e.repository.ListAllInstances() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(instances)) + + for _, instance := range instances { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *instance.InstanceId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_internet_gateway_enumerator.go b/enumeration/remote/aws/ec2_internet_gateway_enumerator.go new file mode 100644 index 00000000..82dedf7c --- /dev/null +++ b/enumeration/remote/aws/ec2_internet_gateway_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2InternetGatewayEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2InternetGatewayEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2InternetGatewayEnumerator { + return &EC2InternetGatewayEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2InternetGatewayEnumerator) SupportedType() resource.ResourceType { + return aws.AwsInternetGatewayResourceType +} + +func (e *EC2InternetGatewayEnumerator) Enumerate() ([]*resource.Resource, error) { + internetGateways, err := e.repository.ListAllInternetGateways() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(internetGateways)) + + for _, internetGateway := range internetGateways { + data := map[string]interface{}{} + if len(internetGateway.Attachments) > 0 && internetGateway.Attachments[0].VpcId != nil { + data["vpc_id"] = *internetGateway.Attachments[0].VpcId + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *internetGateway.InternetGatewayId, + data, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_key_pair_enumerator.go b/enumeration/remote/aws/ec2_key_pair_enumerator.go new file mode 100644 index 00000000..9753a8cf --- /dev/null +++ b/enumeration/remote/aws/ec2_key_pair_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2KeyPairEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2KeyPairEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2KeyPairEnumerator { + return &EC2KeyPairEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2KeyPairEnumerator) SupportedType() resource.ResourceType { + return aws.AwsKeyPairResourceType +} + +func (e *EC2KeyPairEnumerator) Enumerate() ([]*resource.Resource, error) { + keyPairs, err := e.repository.ListAllKeyPairs() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(keyPairs)) + + for _, keyPair := range keyPairs { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *keyPair.KeyName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_nat_gateway_enumerator.go b/enumeration/remote/aws/ec2_nat_gateway_enumerator.go new file mode 100644 index 00000000..1cc1a14f --- /dev/null +++ b/enumeration/remote/aws/ec2_nat_gateway_enumerator.go @@ -0,0 +1,54 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2NatGatewayEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2NatGatewayEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NatGatewayEnumerator { + return &EC2NatGatewayEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2NatGatewayEnumerator) SupportedType() resource.ResourceType { + return aws.AwsNatGatewayResourceType +} + +func (e *EC2NatGatewayEnumerator) Enumerate() ([]*resource.Resource, error) { + natGateways, err := e.repository.ListAllNatGateways() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(natGateways)) + + for _, natGateway := range natGateways { + + attrs := map[string]interface{}{} + if len(natGateway.NatGatewayAddresses) > 0 { + if allocId := natGateway.NatGatewayAddresses[0].AllocationId; allocId != nil { + attrs["allocation_id"] = *allocId + } + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *natGateway.NatGatewayId, + attrs, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_network_acl_enumerator.go b/enumeration/remote/aws/ec2_network_acl_enumerator.go new file mode 100644 index 00000000..3facb89c --- /dev/null +++ b/enumeration/remote/aws/ec2_network_acl_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2NetworkACLEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2NetworkACLEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NetworkACLEnumerator { + return &EC2NetworkACLEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2NetworkACLEnumerator) SupportedType() resource.ResourceType { + return aws.AwsNetworkACLResourceType +} + +func (e *EC2NetworkACLEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllNetworkACLs() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + // Do not handle default network acl since it is a dedicated resource + if *res.IsDefault { + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.NetworkAclId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_network_acl_rule_enumerator.go b/enumeration/remote/aws/ec2_network_acl_rule_enumerator.go new file mode 100644 index 00000000..25aef07a --- /dev/null +++ b/enumeration/remote/aws/ec2_network_acl_rule_enumerator.go @@ -0,0 +1,70 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2NetworkACLRuleEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2NetworkACLRuleEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NetworkACLRuleEnumerator { + return &EC2NetworkACLRuleEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2NetworkACLRuleEnumerator) SupportedType() resource.ResourceType { + return aws.AwsNetworkACLRuleResourceType +} + +func (e *EC2NetworkACLRuleEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllNetworkACLs() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsNetworkACLResourceType) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + for _, entry := range res.Entries { + + attrs := map[string]interface{}{ + "egress": *entry.Egress, + "network_acl_id": *res.NetworkAclId, + "rule_action": *entry.RuleAction, // Used in default middleware + "rule_number": float64(*entry.RuleNumber), // Used in default middleware + "protocol": *entry.Protocol, // Used in default middleware + } + + if entry.CidrBlock != nil { + attrs["cidr_block"] = *entry.CidrBlock + } + + if entry.Ipv6CidrBlock != nil { + attrs["ipv6_cidr_block"] = *entry.Ipv6CidrBlock + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + aws.CreateNetworkACLRuleID( + *res.NetworkAclId, + int(float64(*entry.RuleNumber)), + *entry.Egress, + *entry.Protocol, + ), + attrs, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_route_enumerator.go b/enumeration/remote/aws/ec2_route_enumerator.go new file mode 100644 index 00000000..e41c193f --- /dev/null +++ b/enumeration/remote/aws/ec2_route_enumerator.go @@ -0,0 +1,66 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2RouteEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2RouteEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteEnumerator { + return &EC2RouteEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2RouteEnumerator) SupportedType() resource.ResourceType { + return aws.AwsRouteResourceType +} + +func (e *EC2RouteEnumerator) Enumerate() ([]*resource.Resource, error) { + routeTables, err := e.repository.ListAllRouteTables() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsRouteTableResourceType) + } + + var results []*resource.Resource + + for _, routeTable := range routeTables { + for _, route := range routeTable.Routes { + routeId := aws.CalculateRouteID(routeTable.RouteTableId, route.DestinationCidrBlock, route.DestinationIpv6CidrBlock, route.DestinationPrefixListId) + data := map[string]interface{}{ + "route_table_id": *routeTable.RouteTableId, + "origin": *route.Origin, + } + if route.DestinationCidrBlock != nil && *route.DestinationCidrBlock != "" { + data["destination_cidr_block"] = *route.DestinationCidrBlock + } + if route.DestinationIpv6CidrBlock != nil && *route.DestinationIpv6CidrBlock != "" { + data["destination_ipv6_cidr_block"] = *route.DestinationIpv6CidrBlock + } + if route.DestinationPrefixListId != nil && *route.DestinationPrefixListId != "" { + data["destination_prefix_list_id"] = *route.DestinationPrefixListId + } + if route.GatewayId != nil && *route.GatewayId != "" { + data["gateway_id"] = *route.GatewayId + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + routeId, + data, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/ec2_route_table_association_enumerator.go b/enumeration/remote/aws/ec2_route_table_association_enumerator.go new file mode 100644 index 00000000..4ccd6335 --- /dev/null +++ b/enumeration/remote/aws/ec2_route_table_association_enumerator.go @@ -0,0 +1,69 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2RouteTableAssociationEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2RouteTableAssociationEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteTableAssociationEnumerator { + return &EC2RouteTableAssociationEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2RouteTableAssociationEnumerator) SupportedType() resource.ResourceType { + return aws.AwsRouteTableAssociationResourceType +} + +func (e *EC2RouteTableAssociationEnumerator) Enumerate() ([]*resource.Resource, error) { + routeTables, err := e.repository.ListAllRouteTables() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsRouteTableResourceType) + } + + var results []*resource.Resource + + for _, routeTable := range routeTables { + for _, assoc := range routeTable.Associations { + if e.shouldBeIgnored(assoc) { + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *assoc.RouteTableAssociationId, + map[string]interface{}{ + "route_table_id": *assoc.RouteTableId, + }, + ), + ) + } + } + + return results, err +} + +func (e *EC2RouteTableAssociationEnumerator) shouldBeIgnored(assoc *ec2.RouteTableAssociation) bool { + // Ignore when nothing is associated + if assoc.GatewayId == nil && assoc.SubnetId == nil { + return true + } + + // Ignore when association is not associated + if assoc.AssociationState != nil && assoc.AssociationState.State != nil && + *assoc.AssociationState.State != "associated" { + return true + } + + return false +} diff --git a/enumeration/remote/aws/ec2_route_table_enumerator.go b/enumeration/remote/aws/ec2_route_table_enumerator.go new file mode 100644 index 00000000..052d278d --- /dev/null +++ b/enumeration/remote/aws/ec2_route_table_enumerator.go @@ -0,0 +1,58 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2RouteTableEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2RouteTableEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteTableEnumerator { + return &EC2RouteTableEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2RouteTableEnumerator) SupportedType() resource.ResourceType { + return aws.AwsRouteTableResourceType +} + +func (e *EC2RouteTableEnumerator) Enumerate() ([]*resource.Resource, error) { + routeTables, err := e.repository.ListAllRouteTables() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + var results []*resource.Resource + + for _, routeTable := range routeTables { + if !isMainRouteTable(routeTable) { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *routeTable.RouteTableId, + map[string]interface{}{}, + ), + ) + } + } + + return results, err +} + +func isMainRouteTable(routeTable *ec2.RouteTable) bool { + for _, assoc := range routeTable.Associations { + if assoc.Main != nil && *assoc.Main { + return true + } + } + return false +} diff --git a/enumeration/remote/aws/ec2_subnet_enumerator.go b/enumeration/remote/aws/ec2_subnet_enumerator.go new file mode 100644 index 00000000..77592046 --- /dev/null +++ b/enumeration/remote/aws/ec2_subnet_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type EC2SubnetEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewEC2SubnetEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2SubnetEnumerator { + return &EC2SubnetEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *EC2SubnetEnumerator) SupportedType() resource.ResourceType { + return aws.AwsSubnetResourceType +} + +func (e *EC2SubnetEnumerator) Enumerate() ([]*resource.Resource, error) { + subnets, _, err := e.repository.ListAllSubnets() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(subnets)) + + for _, subnet := range subnets { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *subnet.SubnetId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ecr_repository_enumerator.go b/enumeration/remote/aws/ecr_repository_enumerator.go new file mode 100644 index 00000000..7ad47496 --- /dev/null +++ b/enumeration/remote/aws/ecr_repository_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ECRRepositoryEnumerator struct { + repository repository.ECRRepository + factory resource.ResourceFactory +} + +func NewECRRepositoryEnumerator(repo repository.ECRRepository, factory resource.ResourceFactory) *ECRRepositoryEnumerator { + return &ECRRepositoryEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ECRRepositoryEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEcrRepositoryResourceType +} + +func (e *ECRRepositoryEnumerator) Enumerate() ([]*resource.Resource, error) { + repos, err := e.repository.ListAllRepositories() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(repos)) + + for _, repo := range repos { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *repo.RepositoryName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/ecr_repository_policy_enumerator.go b/enumeration/remote/aws/ecr_repository_policy_enumerator.go new file mode 100644 index 00000000..990302ef --- /dev/null +++ b/enumeration/remote/aws/ecr_repository_policy_enumerator.go @@ -0,0 +1,55 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go/service/ecr" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ECRRepositoryPolicyEnumerator struct { + repository repository.ECRRepository + factory resource.ResourceFactory +} + +func NewECRRepositoryPolicyEnumerator(repo repository.ECRRepository, factory resource.ResourceFactory) *ECRRepositoryPolicyEnumerator { + return &ECRRepositoryPolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ECRRepositoryPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsEcrRepositoryPolicyResourceType +} + +func (e *ECRRepositoryPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + repos, err := e.repository.ListAllRepositories() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsEcrRepositoryResourceType) + } + + results := make([]*resource.Resource, 0, len(repos)) + + for _, repo := range repos { + repoOutput, err := e.repository.GetRepositoryPolicy(repo) + if _, ok := err.(*ecr.RepositoryPolicyNotFoundException); ok { + continue + } + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *repoOutput.RepositoryName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/elasticache_cluster_enumerator.go b/enumeration/remote/aws/elasticache_cluster_enumerator.go new file mode 100644 index 00000000..2ffd0eae --- /dev/null +++ b/enumeration/remote/aws/elasticache_cluster_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type ElastiCacheClusterEnumerator struct { + repository repository.ElastiCacheRepository + factory resource.ResourceFactory +} + +func NewElastiCacheClusterEnumerator(repo repository.ElastiCacheRepository, factory resource.ResourceFactory) *ElastiCacheClusterEnumerator { + return &ElastiCacheClusterEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *ElastiCacheClusterEnumerator) SupportedType() resource.ResourceType { + return aws.AwsElastiCacheClusterResourceType +} + +func (e *ElastiCacheClusterEnumerator) Enumerate() ([]*resource.Resource, error) { + clusters, err := e.repository.ListAllCacheClusters() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(clusters)) + + for _, cluster := range clusters { + c := cluster + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *c.CacheClusterId, + map[string]interface{}{}, + ), + ) + } + return results, err +} diff --git a/enumeration/remote/aws/iam_access_key_enumerator.go b/enumeration/remote/aws/iam_access_key_enumerator.go new file mode 100644 index 00000000..73306046 --- /dev/null +++ b/enumeration/remote/aws/iam_access_key_enumerator.go @@ -0,0 +1,52 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamAccessKeyEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamAccessKeyEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamAccessKeyEnumerator { + return &IamAccessKeyEnumerator{ + repository, + factory, + } +} + +func (e *IamAccessKeyEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsIamAccessKeyResourceType +} + +func (e *IamAccessKeyEnumerator) Enumerate() ([]*resource.Resource, error) { + users, err := e.repository.ListAllUsers() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamUserResourceType) + } + + keys, err := e.repository.ListAllAccessKeys(users) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + for _, key := range keys { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *key.AccessKeyId, + map[string]interface{}{ + "user": *key.UserName, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/iam_group_enumerator.go b/enumeration/remote/aws/iam_group_enumerator.go new file mode 100644 index 00000000..ba0931b8 --- /dev/null +++ b/enumeration/remote/aws/iam_group_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamGroupEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamGroupEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamGroupEnumerator { + return &IamGroupEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *IamGroupEnumerator) SupportedType() resource.ResourceType { + return aws.AwsIamGroupResourceType +} + +func (e *IamGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + groups, err := e.repository.ListAllGroups() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamGroupResourceType) + } + + results := make([]*resource.Resource, 0, len(groups)) + + for _, group := range groups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *group.GroupName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/iam_group_policy_attachment_enumerator.go b/enumeration/remote/aws/iam_group_policy_attachment_enumerator.go new file mode 100644 index 00000000..b3e130d7 --- /dev/null +++ b/enumeration/remote/aws/iam_group_policy_attachment_enumerator.go @@ -0,0 +1,56 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamGroupPolicyAttachmentEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamGroupPolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamGroupPolicyAttachmentEnumerator { + return &IamGroupPolicyAttachmentEnumerator{ + repository, + factory, + } +} + +func (e *IamGroupPolicyAttachmentEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsIamGroupPolicyAttachmentResourceType +} + +func (e *IamGroupPolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) { + groups, err := e.repository.ListAllGroups() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamGroupResourceType) + } + + results := make([]*resource.Resource, 0) + + policyAttachments, err := e.repository.ListAllGroupPolicyAttachments(groups) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, attachedPol := range policyAttachments { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.GroupName), + map[string]interface{}{ + "group": attachedPol.GroupName, + "policy_arn": *attachedPol.PolicyArn, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/iam_group_policy_enumerator.go b/enumeration/remote/aws/iam_group_policy_enumerator.go new file mode 100644 index 00000000..daff7f61 --- /dev/null +++ b/enumeration/remote/aws/iam_group_policy_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamGroupPolicyEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamGroupPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamGroupPolicyEnumerator { + return &IamGroupPolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *IamGroupPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsIamGroupPolicyResourceType +} + +func (e *IamGroupPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + groups, err := e.repository.ListAllGroups() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamGroupResourceType) + } + groupPolicies, err := e.repository.ListAllGroupPolicies(groups) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(groupPolicies)) + + for _, groupPolicy := range groupPolicies { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + groupPolicy, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/iam_policy_enumerator.go b/enumeration/remote/aws/iam_policy_enumerator.go new file mode 100644 index 00000000..0f17dd74 --- /dev/null +++ b/enumeration/remote/aws/iam_policy_enumerator.go @@ -0,0 +1,47 @@ +package aws + +import ( + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamPolicyEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamPolicyEnumerator { + return &IamPolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *IamPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsIamPolicyResourceType +} + +func (e *IamPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + policies, err := e.repository.ListAllPolicies() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(policies)) + + for _, policy := range policies { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + awssdk.StringValue(policy.Arn), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/iam_role_enumerator.go b/enumeration/remote/aws/iam_role_enumerator.go new file mode 100644 index 00000000..0e142e4c --- /dev/null +++ b/enumeration/remote/aws/iam_role_enumerator.go @@ -0,0 +1,65 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +var iamRoleExclusionList = map[string]struct{}{ + // Enabled by default for aws to enable support, not removable + "AWSServiceRoleForSupport": {}, + // Enabled and not removable for every org account + "AWSServiceRoleForOrganizations": {}, + // Not manageable by IaC and set by default + "AWSServiceRoleForTrustedAdvisor": {}, +} + +type IamRoleEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamRoleEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRoleEnumerator { + return &IamRoleEnumerator{ + repository, + factory, + } +} + +func (e *IamRoleEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsIamRoleResourceType +} + +func awsIamRoleShouldBeIgnored(roleName string) bool { + _, ok := iamRoleExclusionList[roleName] + return ok +} + +func (e *IamRoleEnumerator) Enumerate() ([]*resource.Resource, error) { + roles, err := e.repository.ListAllRoles() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + for _, role := range roles { + if role.RoleName != nil && awsIamRoleShouldBeIgnored(*role.RoleName) { + continue + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *role.RoleName, + map[string]interface{}{ + "path": *role.Path, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/iam_role_policy_attachment_enumerator.go b/enumeration/remote/aws/iam_role_policy_attachment_enumerator.go new file mode 100644 index 00000000..c2e0c2cb --- /dev/null +++ b/enumeration/remote/aws/iam_role_policy_attachment_enumerator.go @@ -0,0 +1,69 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/aws/aws-sdk-go/service/iam" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamRolePolicyAttachmentEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamRolePolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRolePolicyAttachmentEnumerator { + return &IamRolePolicyAttachmentEnumerator{ + repository, + factory, + } +} + +func (e *IamRolePolicyAttachmentEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsIamRolePolicyAttachmentResourceType +} + +func (e *IamRolePolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) { + roles, err := e.repository.ListAllRoles() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamRoleResourceType) + } + + results := make([]*resource.Resource, 0) + rolesNotIgnored := make([]*iam.Role, 0) + + for _, role := range roles { + if role.RoleName != nil && awsIamRoleShouldBeIgnored(*role.RoleName) { + continue + } + rolesNotIgnored = append(rolesNotIgnored, role) + } + + if len(rolesNotIgnored) == 0 { + return results, nil + } + + policyAttachments, err := e.repository.ListAllRolePolicyAttachments(rolesNotIgnored) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, attachedPol := range policyAttachments { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.RoleName), + map[string]interface{}{ + "role": attachedPol.RoleName, + "policy_arn": *attachedPol.PolicyArn, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/iam_role_policy_enumerator.go b/enumeration/remote/aws/iam_role_policy_enumerator.go new file mode 100644 index 00000000..34045585 --- /dev/null +++ b/enumeration/remote/aws/iam_role_policy_enumerator.go @@ -0,0 +1,54 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamRolePolicyEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamRolePolicyEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRolePolicyEnumerator { + return &IamRolePolicyEnumerator{ + repository, + factory, + } +} + +func (e *IamRolePolicyEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsIamRolePolicyResourceType +} + +func (e *IamRolePolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + roles, err := e.repository.ListAllRoles() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamRoleResourceType) + } + + policies, err := e.repository.ListAllRolePolicies(roles) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(policies)) + for _, policy := range policies { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + fmt.Sprintf("%s:%s", policy.RoleName, policy.Policy), + map[string]interface{}{ + "role": policy.RoleName, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/iam_user_enumerator.go b/enumeration/remote/aws/iam_user_enumerator.go new file mode 100644 index 00000000..23118fc4 --- /dev/null +++ b/enumeration/remote/aws/iam_user_enumerator.go @@ -0,0 +1,47 @@ +package aws + +import ( + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamUserEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamUserEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamUserEnumerator { + return &IamUserEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *IamUserEnumerator) SupportedType() resource.ResourceType { + return aws.AwsIamUserResourceType +} + +func (e *IamUserEnumerator) Enumerate() ([]*resource.Resource, error) { + users, err := e.repository.ListAllUsers() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(users)) + + for _, user := range users { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + awssdk.StringValue(user.UserName), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/iam_user_policy_attachment_enumerator.go b/enumeration/remote/aws/iam_user_policy_attachment_enumerator.go new file mode 100644 index 00000000..f910bb84 --- /dev/null +++ b/enumeration/remote/aws/iam_user_policy_attachment_enumerator.go @@ -0,0 +1,55 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamUserPolicyAttachmentEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamUserPolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamUserPolicyAttachmentEnumerator { + return &IamUserPolicyAttachmentEnumerator{ + repository, + factory, + } +} + +func (e *IamUserPolicyAttachmentEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsIamUserPolicyAttachmentResourceType +} + +func (e *IamUserPolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) { + users, err := e.repository.ListAllUsers() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamUserResourceType) + } + + results := make([]*resource.Resource, 0) + policyAttachments, err := e.repository.ListAllUserPolicyAttachments(users) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, attachedPol := range policyAttachments { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.UserName), + map[string]interface{}{ + "user": attachedPol.UserName, + "policy_arn": *attachedPol.PolicyArn, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/iam_user_policy_enumerator.go b/enumeration/remote/aws/iam_user_policy_enumerator.go new file mode 100644 index 00000000..14c6a729 --- /dev/null +++ b/enumeration/remote/aws/iam_user_policy_enumerator.go @@ -0,0 +1,50 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type IamUserPolicyEnumerator struct { + repository repository.IAMRepository + factory resource.ResourceFactory +} + +func NewIamUserPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamUserPolicyEnumerator { + return &IamUserPolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *IamUserPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsIamUserPolicyResourceType +} + +func (e *IamUserPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + users, err := e.repository.ListAllUsers() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamUserResourceType) + } + userPolicies, err := e.repository.ListAllUserPolicies(users) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(userPolicies)) + + for _, userPolicy := range userPolicies { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + userPolicy, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/init.go b/enumeration/remote/aws/init.go new file mode 100644 index 00000000..5d8e3409 --- /dev/null +++ b/enumeration/remote/aws/init.go @@ -0,0 +1,258 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/aws/client" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/enumeration/terraform" +) + +/** + * Initialize remote (configure credentials, launch tf providers and start gRPC clients) + * Required to use Scanner + */ + +func Init(version string, alerter *alerter.Alerter, + providerLibrary *terraform.ProviderLibrary, + remoteLibrary *common2.RemoteLibrary, + progress enumeration.ProgressCounter, + resourceSchemaRepository *resource.SchemaRepository, + factory resource.ResourceFactory, + configDir string) error { + + provider, err := NewAWSTerraformProvider(version, progress, configDir) + if err != nil { + return err + } + err = provider.CheckCredentialsExist() + if err != nil { + return err + } + err = provider.Init() + if err != nil { + return err + } + + repositoryCache := cache.New(100) + + s3Repository := repository2.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache) + ec2repository := repository2.NewEC2Repository(provider.session, repositoryCache) + elbv2Repository := repository2.NewELBV2Repository(provider.session, repositoryCache) + route53repository := repository2.NewRoute53Repository(provider.session, repositoryCache) + lambdaRepository := repository2.NewLambdaRepository(provider.session, repositoryCache) + rdsRepository := repository2.NewRDSRepository(provider.session, repositoryCache) + sqsRepository := repository2.NewSQSRepository(provider.session, repositoryCache) + snsRepository := repository2.NewSNSRepository(provider.session, repositoryCache) + cloudfrontRepository := repository2.NewCloudfrontRepository(provider.session, repositoryCache) + dynamoDBRepository := repository2.NewDynamoDBRepository(provider.session, repositoryCache) + ecrRepository := repository2.NewECRRepository(provider.session, repositoryCache) + kmsRepository := repository2.NewKMSRepository(provider.session, repositoryCache) + iamRepository := repository2.NewIAMRepository(provider.session, repositoryCache) + cloudformationRepository := repository2.NewCloudformationRepository(provider.session, repositoryCache) + apigatewayRepository := repository2.NewApiGatewayRepository(provider.session, repositoryCache) + appAutoScalingRepository := repository2.NewAppAutoScalingRepository(provider.session, repositoryCache) + apigatewayv2Repository := repository2.NewApiGatewayV2Repository(provider.session, repositoryCache) + autoscalingRepository := repository2.NewAutoScalingRepository(provider.session, repositoryCache) + elbRepository := repository2.NewELBRepository(provider.session, repositoryCache) + elasticacheRepository := repository2.NewElastiCacheRepository(provider.session, repositoryCache) + + deserializer := resource.NewDeserializer(factory) + providerLibrary.AddProvider(terraform.AWS, provider) + + remoteLibrary.AddEnumerator(NewS3BucketEnumerator(s3Repository, factory, provider.Config, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewS3BucketInventoryEnumerator(s3Repository, factory, provider.Config, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketInventoryResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketInventoryResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewS3BucketNotificationEnumerator(s3Repository, factory, provider.Config, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketNotificationResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketNotificationResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewS3BucketMetricsEnumerator(s3Repository, factory, provider.Config, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketMetricResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketMetricResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewS3BucketPolicyEnumerator(s3Repository, factory, provider.Config, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketPolicyResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter)) + + remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common2.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2EbsSnapshotEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsEbsSnapshotResourceType, common2.NewGenericDetailsFetcher(aws.AwsEbsSnapshotResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2EipEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsEipResourceType, common2.NewGenericDetailsFetcher(aws.AwsEipResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2AmiEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsAmiResourceType, common2.NewGenericDetailsFetcher(aws.AwsAmiResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2KeyPairEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsKeyPairResourceType, common2.NewGenericDetailsFetcher(aws.AwsKeyPairResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2EipAssociationEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsEipAssociationResourceType, common2.NewGenericDetailsFetcher(aws.AwsEipAssociationResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2InstanceEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsInstanceResourceType, common2.NewGenericDetailsFetcher(aws.AwsInstanceResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2InternetGatewayEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsInternetGatewayResourceType, common2.NewGenericDetailsFetcher(aws.AwsInternetGatewayResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewVPCEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsVpcResourceType, common2.NewGenericDetailsFetcher(aws.AwsVpcResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewDefaultVPCEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDefaultVpcResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultVpcResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2RouteTableEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRouteTableResourceType, common2.NewGenericDetailsFetcher(aws.AwsRouteTableResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2DefaultRouteTableEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDefaultRouteTableResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultRouteTableResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2RouteTableAssociationEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRouteTableAssociationResourceType, common2.NewGenericDetailsFetcher(aws.AwsRouteTableAssociationResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2SubnetEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSubnetResourceType, common2.NewGenericDetailsFetcher(aws.AwsSubnetResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2DefaultSubnetEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDefaultSubnetResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultSubnetResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewVPCSecurityGroupEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSecurityGroupResourceType, common2.NewGenericDetailsFetcher(aws.AwsSecurityGroupResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewVPCDefaultSecurityGroupEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDefaultSecurityGroupResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultSecurityGroupResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2NatGatewayEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsNatGatewayResourceType, common2.NewGenericDetailsFetcher(aws.AwsNatGatewayResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2NetworkACLEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsNetworkACLResourceType, common2.NewGenericDetailsFetcher(aws.AwsNetworkACLResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2NetworkACLRuleEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsNetworkACLRuleResourceType, common2.NewGenericDetailsFetcher(aws.AwsNetworkACLRuleResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2DefaultNetworkACLEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDefaultNetworkACLResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultNetworkACLResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2RouteEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRouteResourceType, common2.NewGenericDetailsFetcher(aws.AwsRouteResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewVPCSecurityGroupRuleEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSecurityGroupRuleResourceType, common2.NewGenericDetailsFetcher(aws.AwsSecurityGroupRuleResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewLaunchTemplateEnumerator(ec2repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsLaunchTemplateResourceType, common2.NewGenericDetailsFetcher(aws.AwsLaunchTemplateResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewEC2EbsEncryptionByDefaultEnumerator(ec2repository, factory)) + + remoteLibrary.AddEnumerator(NewKMSKeyEnumerator(kmsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsKmsKeyResourceType, common2.NewGenericDetailsFetcher(aws.AwsKmsKeyResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewKMSAliasEnumerator(kmsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsKmsAliasResourceType, common2.NewGenericDetailsFetcher(aws.AwsKmsAliasResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewRoute53HealthCheckEnumerator(route53repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRoute53HealthCheckResourceType, common2.NewGenericDetailsFetcher(aws.AwsRoute53HealthCheckResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewRoute53ZoneEnumerator(route53repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRoute53ZoneResourceType, common2.NewGenericDetailsFetcher(aws.AwsRoute53ZoneResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewRoute53RecordEnumerator(route53repository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRoute53RecordResourceType, common2.NewGenericDetailsFetcher(aws.AwsRoute53RecordResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewCloudfrontDistributionEnumerator(cloudfrontRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsCloudfrontDistributionResourceType, common2.NewGenericDetailsFetcher(aws.AwsCloudfrontDistributionResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewRDSDBInstanceEnumerator(rdsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDbInstanceResourceType, common2.NewGenericDetailsFetcher(aws.AwsDbInstanceResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewRDSDBSubnetGroupEnumerator(rdsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDbSubnetGroupResourceType, common2.NewGenericDetailsFetcher(aws.AwsDbSubnetGroupResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewSQSQueueEnumerator(sqsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSqsQueueResourceType, NewSQSQueueDetailsFetcher(provider, deserializer)) + remoteLibrary.AddEnumerator(NewSQSQueuePolicyEnumerator(sqsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSqsQueuePolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsSqsQueuePolicyResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewSNSTopicEnumerator(snsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicResourceType, common2.NewGenericDetailsFetcher(aws.AwsSnsTopicResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewSNSTopicPolicyEnumerator(snsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsSnsTopicPolicyResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewSNSTopicSubscriptionEnumerator(snsRepository, factory, alerter)) + remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicSubscriptionResourceType, common2.NewGenericDetailsFetcher(aws.AwsSnsTopicSubscriptionResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewDynamoDBTableEnumerator(dynamoDBRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsDynamodbTableResourceType, common2.NewGenericDetailsFetcher(aws.AwsDynamodbTableResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewIamPolicyEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamPolicyResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewLambdaFunctionEnumerator(lambdaRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsLambdaFunctionResourceType, common2.NewGenericDetailsFetcher(aws.AwsLambdaFunctionResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewLambdaEventSourceMappingEnumerator(lambdaRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsLambdaEventSourceMappingResourceType, common2.NewGenericDetailsFetcher(aws.AwsLambdaEventSourceMappingResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewIamUserEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamUserResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamUserResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamUserPolicyEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamUserPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamUserPolicyResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamRoleEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamRoleResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamRoleResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamAccessKeyEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamAccessKeyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamAccessKeyResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamRolePolicyAttachmentEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamRolePolicyAttachmentResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamRolePolicyAttachmentResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamRolePolicyEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamRolePolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamRolePolicyResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamUserPolicyAttachmentEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsIamUserPolicyAttachmentResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamUserPolicyAttachmentResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewIamGroupPolicyEnumerator(iamRepository, factory)) + remoteLibrary.AddEnumerator(NewIamGroupEnumerator(iamRepository, factory)) + remoteLibrary.AddEnumerator(NewIamGroupPolicyAttachmentEnumerator(iamRepository, factory)) + + remoteLibrary.AddEnumerator(NewECRRepositoryEnumerator(ecrRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsEcrRepositoryResourceType, common2.NewGenericDetailsFetcher(aws.AwsEcrRepositoryResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewECRRepositoryPolicyEnumerator(ecrRepository, factory)) + + remoteLibrary.AddEnumerator(NewRDSClusterEnumerator(rdsRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsRDSClusterResourceType, common2.NewGenericDetailsFetcher(aws.AwsRDSClusterResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewCloudformationStackEnumerator(cloudformationRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsCloudformationStackResourceType, common2.NewGenericDetailsFetcher(aws.AwsCloudformationStackResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewApiGatewayRestApiEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayAccountEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayApiKeyEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayAuthorizerEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayStageEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayResourceEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayDomainNameEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayVpcLinkEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayRequestValidatorEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayRestApiPolicyEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayBasePathMappingEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayMethodEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayModelEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayMethodResponseEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayGatewayResponseEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayMethodSettingsEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayIntegrationEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayIntegrationResponseEnumerator(apigatewayRepository, factory)) + + remoteLibrary.AddEnumerator(NewApiGatewayV2ApiEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2RouteEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2DeploymentEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2VpcLinkEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2AuthorizerEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2IntegrationEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2ModelEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2StageEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2RouteResponseEnumerator(apigatewayv2Repository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2MappingEnumerator(apigatewayv2Repository, apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2DomainNameEnumerator(apigatewayRepository, factory)) + remoteLibrary.AddEnumerator(NewApiGatewayV2IntegrationResponseEnumerator(apigatewayv2Repository, factory)) + + remoteLibrary.AddEnumerator(NewAppAutoscalingTargetEnumerator(appAutoScalingRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsAppAutoscalingTargetResourceType, common2.NewGenericDetailsFetcher(aws.AwsAppAutoscalingTargetResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewAppAutoscalingPolicyEnumerator(appAutoScalingRepository, factory)) + remoteLibrary.AddDetailsFetcher(aws.AwsAppAutoscalingPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsAppAutoscalingPolicyResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewAppAutoscalingScheduledActionEnumerator(appAutoScalingRepository, factory)) + + remoteLibrary.AddEnumerator(NewLaunchConfigurationEnumerator(autoscalingRepository, factory)) + + remoteLibrary.AddEnumerator(NewLoadBalancerEnumerator(elbv2Repository, factory)) + remoteLibrary.AddEnumerator(NewLoadBalancerListenerEnumerator(elbv2Repository, factory)) + + remoteLibrary.AddEnumerator(NewClassicLoadBalancerEnumerator(elbRepository, factory)) + + remoteLibrary.AddEnumerator(NewElastiCacheClusterEnumerator(elasticacheRepository, factory)) + + err = resourceSchemaRepository.Init(terraform.AWS, provider.Version(), provider.Schema()) + if err != nil { + return err + } + aws.InitResourcesMetadata(resourceSchemaRepository) + + return nil +} diff --git a/enumeration/remote/aws/kms_alias_enumerator.go b/enumeration/remote/aws/kms_alias_enumerator.go new file mode 100644 index 00000000..44f841d2 --- /dev/null +++ b/enumeration/remote/aws/kms_alias_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type KMSAliasEnumerator struct { + repository repository.KMSRepository + factory resource.ResourceFactory +} + +func NewKMSAliasEnumerator(repo repository.KMSRepository, factory resource.ResourceFactory) *KMSAliasEnumerator { + return &KMSAliasEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *KMSAliasEnumerator) SupportedType() resource.ResourceType { + return aws.AwsKmsAliasResourceType +} + +func (e *KMSAliasEnumerator) Enumerate() ([]*resource.Resource, error) { + aliases, err := e.repository.ListAllAliases() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(aliases)) + + for _, alias := range aliases { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *alias.AliasName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/kms_key_enumerator.go b/enumeration/remote/aws/kms_key_enumerator.go new file mode 100644 index 00000000..f2078c96 --- /dev/null +++ b/enumeration/remote/aws/kms_key_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type KMSKeyEnumerator struct { + repository repository.KMSRepository + factory resource.ResourceFactory +} + +func NewKMSKeyEnumerator(repo repository.KMSRepository, factory resource.ResourceFactory) *KMSKeyEnumerator { + return &KMSKeyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *KMSKeyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsKmsKeyResourceType +} + +func (e *KMSKeyEnumerator) Enumerate() ([]*resource.Resource, error) { + keys, err := e.repository.ListAllKeys() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(keys)) + + for _, key := range keys { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *key.KeyId, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/lambda_event_source_mapping_enumerator.go b/enumeration/remote/aws/lambda_event_source_mapping_enumerator.go new file mode 100644 index 00000000..0fb03c17 --- /dev/null +++ b/enumeration/remote/aws/lambda_event_source_mapping_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type LambdaEventSourceMappingEnumerator struct { + repository repository.LambdaRepository + factory resource.ResourceFactory +} + +func NewLambdaEventSourceMappingEnumerator(repo repository.LambdaRepository, factory resource.ResourceFactory) *LambdaEventSourceMappingEnumerator { + return &LambdaEventSourceMappingEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *LambdaEventSourceMappingEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsLambdaEventSourceMappingResourceType +} + +func (e *LambdaEventSourceMappingEnumerator) Enumerate() ([]*resource.Resource, error) { + eventSourceMappings, err := e.repository.ListAllLambdaEventSourceMappings() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(eventSourceMappings)) + + for _, eventSourceMapping := range eventSourceMappings { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *eventSourceMapping.UUID, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/lambda_function_enumerator.go b/enumeration/remote/aws/lambda_function_enumerator.go new file mode 100644 index 00000000..ec3c2320 --- /dev/null +++ b/enumeration/remote/aws/lambda_function_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type LambdaFunctionEnumerator struct { + repository repository.LambdaRepository + factory resource.ResourceFactory +} + +func NewLambdaFunctionEnumerator(repo repository.LambdaRepository, factory resource.ResourceFactory) *LambdaFunctionEnumerator { + return &LambdaFunctionEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *LambdaFunctionEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsLambdaFunctionResourceType +} + +func (e *LambdaFunctionEnumerator) Enumerate() ([]*resource.Resource, error) { + functions, err := e.repository.ListAllLambdaFunctions() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(functions)) + + for _, function := range functions { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *function.FunctionName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/launch_configuration_enumerator.go b/enumeration/remote/aws/launch_configuration_enumerator.go new file mode 100644 index 00000000..c5265522 --- /dev/null +++ b/enumeration/remote/aws/launch_configuration_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type LaunchConfigurationEnumerator struct { + repository repository.AutoScalingRepository + factory resource.ResourceFactory +} + +func NewLaunchConfigurationEnumerator(repo repository.AutoScalingRepository, factory resource.ResourceFactory) *LaunchConfigurationEnumerator { + return &LaunchConfigurationEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *LaunchConfigurationEnumerator) SupportedType() resource.ResourceType { + return aws.AwsLaunchConfigurationResourceType +} + +func (e *LaunchConfigurationEnumerator) Enumerate() ([]*resource.Resource, error) { + configs, err := e.repository.DescribeLaunchConfigurations() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(configs)) + + for _, config := range configs { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *config.LaunchConfigurationName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/launch_template_enumerator.go b/enumeration/remote/aws/launch_template_enumerator.go new file mode 100644 index 00000000..b19b98ac --- /dev/null +++ b/enumeration/remote/aws/launch_template_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type LaunchTemplateEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewLaunchTemplateEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *LaunchTemplateEnumerator { + return &LaunchTemplateEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *LaunchTemplateEnumerator) SupportedType() resource.ResourceType { + return aws.AwsLaunchTemplateResourceType +} + +func (e *LaunchTemplateEnumerator) Enumerate() ([]*resource.Resource, error) { + templates, err := e.repository.DescribeLaunchTemplates() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(templates)) + + for _, tmpl := range templates { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *tmpl.LaunchTemplateId, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/load_balancer_enumerator.go b/enumeration/remote/aws/load_balancer_enumerator.go new file mode 100644 index 00000000..77746a50 --- /dev/null +++ b/enumeration/remote/aws/load_balancer_enumerator.go @@ -0,0 +1,48 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type LoadBalancerEnumerator struct { + repository repository.ELBV2Repository + factory resource.ResourceFactory +} + +func NewLoadBalancerEnumerator(repo repository.ELBV2Repository, factory resource.ResourceFactory) *LoadBalancerEnumerator { + return &LoadBalancerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *LoadBalancerEnumerator) SupportedType() resource.ResourceType { + return aws.AwsLoadBalancerResourceType +} + +func (e *LoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) { + loadBalancers, err := e.repository.ListAllLoadBalancers() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(loadBalancers)) + + for _, lb := range loadBalancers { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *lb.LoadBalancerArn, + map[string]interface{}{ + "name": *lb.LoadBalancerName, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/load_balancer_listener_enumerator.go b/enumeration/remote/aws/load_balancer_listener_enumerator.go new file mode 100644 index 00000000..cc0d0ec5 --- /dev/null +++ b/enumeration/remote/aws/load_balancer_listener_enumerator.go @@ -0,0 +1,53 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type LoadBalancerListenerEnumerator struct { + repository repository.ELBV2Repository + factory resource.ResourceFactory +} + +func NewLoadBalancerListenerEnumerator(repo repository.ELBV2Repository, factory resource.ResourceFactory) *LoadBalancerListenerEnumerator { + return &LoadBalancerListenerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *LoadBalancerListenerEnumerator) SupportedType() resource.ResourceType { + return aws.AwsLoadBalancerListenerResourceType +} + +func (e *LoadBalancerListenerEnumerator) Enumerate() ([]*resource.Resource, error) { + loadBalancers, err := e.repository.ListAllLoadBalancers() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsLoadBalancerResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, lb := range loadBalancers { + listeners, err := e.repository.ListAllLoadBalancerListeners(*lb.LoadBalancerArn) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, listener := range listeners { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *listener.ListenerArn, + map[string]interface{}{}, + ), + ) + } + } + + return results, nil +} diff --git a/enumeration/remote/aws/provider.go b/enumeration/remote/aws/provider.go new file mode 100644 index 00000000..1bd10f19 --- /dev/null +++ b/enumeration/remote/aws/provider.go @@ -0,0 +1,119 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/terraform" + tf "github.com/snyk/driftctl/enumeration/terraform" +) + +type awsConfig struct { + AccessKey string + SecretKey string + CredsFilename string + Profile string + Token string + Region string `cty:"region"` + MaxRetries int + + AssumeRoleARN string + AssumeRoleExternalID string + AssumeRoleSessionName string + AssumeRolePolicy string + + AllowedAccountIds []string + ForbiddenAccountIds []string + + Endpoints map[string]string + IgnoreTagsConfig map[string]string + Insecure bool + + SkipCredsValidation bool `cty:"skip_credentials_validation"` + SkipGetEC2Platforms bool + SkipRegionValidation bool + SkipRequestingAccountId bool `cty:"skip_requesting_account_id"` + SkipMetadataApiCheck bool + S3ForcePathStyle bool +} + +type AWSTerraformProvider struct { + *terraform.TerraformProvider + session *session.Session + name string + version string +} + +func NewAWSTerraformProvider(version string, progress enumeration.ProgressCounter, configDir string) (*AWSTerraformProvider, error) { + if version == "" { + version = "3.19.0" + } + p := &AWSTerraformProvider{ + version: version, + name: "aws", + } + installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ + Key: p.name, + Version: version, + ConfigDir: configDir, + }) + if err != nil { + return nil, err + } + p.session = session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ + Name: p.name, + DefaultAlias: *p.session.Config.Region, + GetProviderConfig: func(alias string) interface{} { + return awsConfig{ + Region: alias, + // Those two parameters are used to make sure that the credentials are not validated when calling + // Configure(). Credentials validation is now handled directly in driftctl + SkipCredsValidation: true, + SkipRequestingAccountId: true, + + MaxRetries: 10, // TODO make this configurable + } + }, + }, progress) + if err != nil { + return nil, err + } + p.TerraformProvider = tfProvider + return p, err +} + +func (a *AWSTerraformProvider) Name() string { + return a.name +} + +func (p *AWSTerraformProvider) Version() string { + return p.version +} + +func (p *AWSTerraformProvider) CheckCredentialsExist() error { + _, err := p.session.Config.Credentials.Get() + if err == credentials.ErrNoValidProvidersFoundInChain { + return errors.New("Could not find a way to authenticate on AWS!\n" + + "Please refer to AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n\n" + + "To use a different cloud provider, use --to=\"gcp+tf\" for GCP or --to=\"azure+tf\" for Azure.") + } + if err != nil { + return err + } + // This call is to make sure that the credentials are valid + // A more complex logic exist in terraform provider, but it's probably not worth to implement it + // https://github.com/hashicorp/terraform-provider-aws/blob/e3959651092864925045a6044961a73137095798/aws/auth_helpers.go#L111 + _, err = sts.New(p.session).GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + logrus.Debug(err) + return errors.New("Could not authenticate successfully on AWS with the provided credentials.\n" + + "Please refer to the AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n") + } + return nil +} diff --git a/enumeration/remote/aws/rds_cluster_enumerator.go b/enumeration/remote/aws/rds_cluster_enumerator.go new file mode 100644 index 00000000..990b2e8c --- /dev/null +++ b/enumeration/remote/aws/rds_cluster_enumerator.go @@ -0,0 +1,55 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type RDSClusterEnumerator struct { + repository repository.RDSRepository + factory resource.ResourceFactory +} + +func NewRDSClusterEnumerator(repository repository.RDSRepository, factory resource.ResourceFactory) *RDSClusterEnumerator { + return &RDSClusterEnumerator{ + repository, + factory, + } +} + +func (e *RDSClusterEnumerator) SupportedType() resource.ResourceType { + return aws.AwsRDSClusterResourceType +} + +func (e *RDSClusterEnumerator) Enumerate() ([]*resource.Resource, error) { + clusters, err := e.repository.ListAllDBClusters() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(clusters)) + + for _, cluster := range clusters { + var databaseName string + + if v := cluster.DatabaseName; v != nil { + databaseName = *cluster.DatabaseName + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *cluster.DBClusterIdentifier, + map[string]interface{}{ + "cluster_identifier": *cluster.DBClusterIdentifier, + "database_name": databaseName, + }, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/rds_db_instance_enumerator.go b/enumeration/remote/aws/rds_db_instance_enumerator.go new file mode 100644 index 00000000..11df1120 --- /dev/null +++ b/enumeration/remote/aws/rds_db_instance_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type RDSDBInstanceEnumerator struct { + repository repository.RDSRepository + factory resource.ResourceFactory +} + +func NewRDSDBInstanceEnumerator(repo repository.RDSRepository, factory resource.ResourceFactory) *RDSDBInstanceEnumerator { + return &RDSDBInstanceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *RDSDBInstanceEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDbInstanceResourceType +} + +func (e *RDSDBInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { + instances, err := e.repository.ListAllDBInstances() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(instances)) + + for _, instance := range instances { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *instance.DBInstanceIdentifier, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/rds_db_subnet_group_enumerator.go b/enumeration/remote/aws/rds_db_subnet_group_enumerator.go new file mode 100644 index 00000000..51973fd6 --- /dev/null +++ b/enumeration/remote/aws/rds_db_subnet_group_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type RDSDBSubnetGroupEnumerator struct { + repository repository.RDSRepository + factory resource.ResourceFactory +} + +func NewRDSDBSubnetGroupEnumerator(repo repository.RDSRepository, factory resource.ResourceFactory) *RDSDBSubnetGroupEnumerator { + return &RDSDBSubnetGroupEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *RDSDBSubnetGroupEnumerator) SupportedType() resource.ResourceType { + return aws.AwsDbSubnetGroupResourceType +} + +func (e *RDSDBSubnetGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + subnetGroups, err := e.repository.ListAllDBSubnetGroups() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(subnetGroups)) + + for _, subnetGroup := range subnetGroups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *subnetGroup.DBSubnetGroupName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/repository/api_gateway_repository.go b/enumeration/remote/aws/repository/api_gateway_repository.go new file mode 100644 index 00000000..bb69d7dd --- /dev/null +++ b/enumeration/remote/aws/repository/api_gateway_repository.go @@ -0,0 +1,285 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/apigateway" + "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" +) + +type ApiGatewayRepository interface { + ListAllRestApis() ([]*apigateway.RestApi, error) + GetAccount() (*apigateway.Account, error) + ListAllApiKeys() ([]*apigateway.ApiKey, error) + ListAllRestApiAuthorizers(string) ([]*apigateway.Authorizer, error) + ListAllRestApiStages(string) ([]*apigateway.Stage, error) + ListAllRestApiResources(string) ([]*apigateway.Resource, error) + ListAllDomainNames() ([]*apigateway.DomainName, error) + ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOutput, error) + ListAllRestApiRequestValidators(string) ([]*apigateway.UpdateRequestValidatorOutput, error) + ListAllDomainNameBasePathMappings(string) ([]*apigateway.BasePathMapping, error) + ListAllRestApiModels(string) ([]*apigateway.Model, error) + ListAllRestApiGatewayResponses(string) ([]*apigateway.UpdateGatewayResponseOutput, error) +} + +type apigatewayRepository struct { + client apigatewayiface.APIGatewayAPI + cache cache.Cache +} + +func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository { + return &apigatewayRepository{ + apigateway.New(session), + c, + } +} + +func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) { + cacheKey := "apigatewayListAllRestApis" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*apigateway.RestApi), nil + } + + var restApis []*apigateway.RestApi + input := apigateway.GetRestApisInput{} + err := r.client.GetRestApisPages(&input, + func(resp *apigateway.GetRestApisOutput, lastPage bool) bool { + restApis = append(restApis, resp.Items...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, restApis) + return restApis, nil +} + +func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { + if v := r.cache.Get("apigatewayGetAccount"); v != nil { + return v.(*apigateway.Account), nil + } + + account, err := r.client.GetAccount(&apigateway.GetAccountInput{}) + if err != nil { + return nil, err + } + + r.cache.Put("apigatewayGetAccount", account) + return account, nil +} + +func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { + if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil { + return v.([]*apigateway.ApiKey), nil + } + + var apiKeys []*apigateway.ApiKey + input := apigateway.GetApiKeysInput{} + err := r.client.GetApiKeysPages(&input, + func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool { + apiKeys = append(apiKeys, resp.Items...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put("apigatewayListAllApiKeys", apiKeys) + return apiKeys, nil +} + +func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) { + cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigateway.Authorizer), nil + } + + input := &apigateway.GetAuthorizersInput{ + RestApiId: &apiId, + } + resources, err := r.client.GetAuthorizers(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway.Stage, error) { + cacheKey := fmt.Sprintf("apigatewayListAllRestApiStages_api_%s", apiId) + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*apigateway.Stage), nil + } + + input := &apigateway.GetStagesInput{ + RestApiId: &apiId, + } + resources, err := r.client.GetStages(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Item) + return resources.Item, nil +} + +func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigateway.Resource, error) { + cacheKey := fmt.Sprintf("apigatewayListAllRestApiResources_api_%s", apiId) + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*apigateway.Resource), nil + } + + var resources []*apigateway.Resource + input := &apigateway.GetResourcesInput{ + RestApiId: &apiId, + Embed: []*string{aws.String("methods")}, + } + err := r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool { + resources = append(resources, res.Items...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources) + return resources, nil +} + +func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, error) { + cacheKey := "apigatewayListAllDomainNames" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*apigateway.DomainName), nil + } + + var domainNames []*apigateway.DomainName + input := apigateway.GetDomainNamesInput{} + err := r.client.GetDomainNamesPages(&input, + func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool { + domainNames = append(domainNames, resp.Items...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, domainNames) + return domainNames, nil +} + +func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOutput, error) { + if v := r.cache.Get("apigatewayListAllVpcLinks"); v != nil { + return v.([]*apigateway.UpdateVpcLinkOutput), nil + } + + var vpcLinks []*apigateway.UpdateVpcLinkOutput + input := apigateway.GetVpcLinksInput{} + err := r.client.GetVpcLinksPages(&input, + func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool { + vpcLinks = append(vpcLinks, resp.Items...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put("apigatewayListAllVpcLinks", vpcLinks) + return vpcLinks, nil +} + +func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([]*apigateway.UpdateRequestValidatorOutput, error) { + cacheKey := fmt.Sprintf("apigatewayListAllRestApiRequestValidators_api_%s", apiId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigateway.UpdateRequestValidatorOutput), nil + } + + input := &apigateway.GetRequestValidatorsInput{ + RestApiId: &apiId, + } + resources, err := r.client.GetRequestValidators(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName string) ([]*apigateway.BasePathMapping, error) { + cacheKey := fmt.Sprintf("apigatewayListAllDomainNameBasePathMappings_domainName_%s", domainName) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigateway.BasePathMapping), nil + } + + var mappings []*apigateway.BasePathMapping + input := &apigateway.GetBasePathMappingsInput{ + DomainName: &domainName, + } + err := r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool { + mappings = append(mappings, res.Items...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, mappings) + return mappings, nil +} + +func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway.Model, error) { + cacheKey := fmt.Sprintf("apigatewayListAllRestApiModels_api_%s", apiId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigateway.Model), nil + } + + var resources []*apigateway.Model + input := &apigateway.GetModelsInput{ + RestApiId: &apiId, + } + err := r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool { + resources = append(resources, res.Items...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources) + return resources, nil +} + +func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]*apigateway.UpdateGatewayResponseOutput, error) { + cacheKey := fmt.Sprintf("apigatewayListAllRestApiGatewayResponses_api_%s", apiId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigateway.UpdateGatewayResponseOutput), nil + } + + input := &apigateway.GetGatewayResponsesInput{ + RestApiId: &apiId, + } + resources, err := r.client.GetGatewayResponses(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} diff --git a/enumeration/remote/aws/repository/api_gateway_repository_test.go b/enumeration/remote/aws/repository/api_gateway_repository_test.go new file mode 100644 index 00000000..8f452d07 --- /dev/null +++ b/enumeration/remote/aws/repository/api_gateway_repository_test.go @@ -0,0 +1,890 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/apigateway" + "github.com/pkg/errors" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_apigatewayRepository_ListAllRestApis(t *testing.T) { + apis := []*apigateway.RestApi{ + {Id: aws.String("restapi1")}, + {Id: aws.String("restapi2")}, + {Id: aws.String("restapi3")}, + {Id: aws.String("restapi4")}, + {Id: aws.String("restapi5")}, + {Id: aws.String("restapi6")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.RestApi + wantErr error + }{ + { + name: "list multiple rest apis", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetRestApisPages", + &apigateway.GetRestApisInput{}, + mock.MatchedBy(func(callback func(res *apigateway.GetRestApisOutput, lastPage bool) bool) bool { + callback(&apigateway.GetRestApisOutput{ + Items: apis[:3], + }, false) + callback(&apigateway.GetRestApisOutput{ + Items: apis[3:], + }, true) + return true + })).Return(nil).Once() + + store.On("GetAndLock", "apigatewayListAllRestApis").Return(nil).Times(1) + store.On("Unlock", "apigatewayListAllRestApis").Times(1) + store.On("Put", "apigatewayListAllRestApis", apis).Return(false).Times(1) + }, + want: apis, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("GetAndLock", "apigatewayListAllRestApis").Return(apis).Times(1) + store.On("Unlock", "apigatewayListAllRestApis").Times(1) + }, + want: apis, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApis() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_GetAccount(t *testing.T) { + account := &apigateway.Account{ + CloudwatchRoleArn: aws.String("arn:aws:iam::017011014111:role/api_gateway_cloudwatch_global"), + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want *apigateway.Account + wantErr error + }{ + { + name: "get a single account", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetAccount", &apigateway.GetAccountInput{}).Return(account, nil).Once() + + store.On("Get", "apigatewayGetAccount").Return(nil).Times(1) + store.On("Put", "apigatewayGetAccount", account).Return(false).Times(1) + }, + want: account, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayGetAccount").Return(account).Times(1) + }, + want: account, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.GetAccount() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllApiKeys(t *testing.T) { + keys := []*apigateway.ApiKey{ + {Id: aws.String("apikey1")}, + {Id: aws.String("apikey2")}, + {Id: aws.String("apikey3")}, + {Id: aws.String("apikey4")}, + {Id: aws.String("apikey5")}, + {Id: aws.String("apikey6")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.ApiKey + wantErr error + }{ + { + name: "list multiple api keys", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetApiKeysPages", + &apigateway.GetApiKeysInput{}, + mock.MatchedBy(func(callback func(res *apigateway.GetApiKeysOutput, lastPage bool) bool) bool { + callback(&apigateway.GetApiKeysOutput{ + Items: keys[:3], + }, false) + callback(&apigateway.GetApiKeysOutput{ + Items: keys[3:], + }, true) + return true + })).Return(nil).Once() + + store.On("Get", "apigatewayListAllApiKeys").Return(nil).Times(1) + store.On("Put", "apigatewayListAllApiKeys", keys).Return(false).Times(1) + }, + want: keys, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllApiKeys").Return(keys).Times(1) + }, + want: keys, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllApiKeys() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllRestApiAuthorizers(t *testing.T) { + api := &apigateway.RestApi{ + Id: aws.String("restapi1"), + } + + apiAuthorizers := []*apigateway.Authorizer{ + {Id: aws.String("resource1")}, + {Id: aws.String("resource2")}, + {Id: aws.String("resource3")}, + {Id: aws.String("resource4")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.Authorizer + wantErr error + }{ + { + name: "list multiple rest api authorizers", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetAuthorizers", + &apigateway.GetAuthorizersInput{ + RestApiId: aws.String("restapi1"), + }).Return(&apigateway.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once() + + store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(nil).Times(1) + store.On("Put", "apigatewayListAllRestApiAuthorizers_api_restapi1", apiAuthorizers).Return(false).Times(1) + }, + want: apiAuthorizers, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(apiAuthorizers).Times(1) + }, + want: apiAuthorizers, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApiAuthorizers(*api.Id) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllRestApiStages(t *testing.T) { + api := &apigateway.RestApi{ + Id: aws.String("restapi1"), + } + + apiStages := []*apigateway.Stage{ + {StageName: aws.String("stage1")}, + {StageName: aws.String("stage2")}, + {StageName: aws.String("stage3")}, + {StageName: aws.String("stage4")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.Stage + wantErr error + }{ + { + name: "list multiple rest api stages", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetStages", + &apigateway.GetStagesInput{ + RestApiId: aws.String("restapi1"), + }).Return(&apigateway.GetStagesOutput{Item: apiStages}, nil).Once() + + store.On("GetAndLock", "apigatewayListAllRestApiStages_api_restapi1").Return(nil).Times(1) + store.On("Unlock", "apigatewayListAllRestApiStages_api_restapi1").Times(1) + store.On("Put", "apigatewayListAllRestApiStages_api_restapi1", apiStages).Return(false).Times(1) + }, + want: apiStages, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("GetAndLock", "apigatewayListAllRestApiStages_api_restapi1").Return(apiStages).Times(1) + store.On("Unlock", "apigatewayListAllRestApiStages_api_restapi1").Times(1) + }, + want: apiStages, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApiStages(*api.Id) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllRestApiResources(t *testing.T) { + api := &apigateway.RestApi{ + Id: aws.String("restapi1"), + } + + apiResources := []*apigateway.Resource{ + {Id: aws.String("resource1")}, + {Id: aws.String("resource2")}, + {Id: aws.String("resource3")}, + {Id: aws.String("resource4")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.Resource + wantErr error + }{ + { + name: "list multiple rest api resources", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetResourcesPages", + &apigateway.GetResourcesInput{ + RestApiId: aws.String("restapi1"), + Embed: []*string{aws.String("methods")}, + }, + mock.MatchedBy(func(callback func(res *apigateway.GetResourcesOutput, lastPage bool) bool) bool { + callback(&apigateway.GetResourcesOutput{ + Items: apiResources, + }, true) + return true + })).Return(nil).Once() + + store.On("GetAndLock", "apigatewayListAllRestApiResources_api_restapi1").Return(nil).Times(1) + store.On("Unlock", "apigatewayListAllRestApiResources_api_restapi1").Times(1) + store.On("Put", "apigatewayListAllRestApiResources_api_restapi1", apiResources).Return(false).Times(1) + }, + want: apiResources, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("GetAndLock", "apigatewayListAllRestApiResources_api_restapi1").Return(apiResources).Times(1) + store.On("Unlock", "apigatewayListAllRestApiResources_api_restapi1").Times(1) + }, + want: apiResources, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApiResources(*api.Id) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllDomainNames(t *testing.T) { + domainNames := []*apigateway.DomainName{ + {DomainName: aws.String("domainName1")}, + {DomainName: aws.String("domainName2")}, + {DomainName: aws.String("domainName3")}, + {DomainName: aws.String("domainName4")}, + {DomainName: aws.String("domainName5")}, + {DomainName: aws.String("domainName6")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.DomainName + wantErr error + }{ + { + name: "list multiple domain names", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetDomainNamesPages", + &apigateway.GetDomainNamesInput{}, + mock.MatchedBy(func(callback func(res *apigateway.GetDomainNamesOutput, lastPage bool) bool) bool { + callback(&apigateway.GetDomainNamesOutput{ + Items: domainNames[:3], + }, false) + callback(&apigateway.GetDomainNamesOutput{ + Items: domainNames[3:], + }, true) + return true + })).Return(nil).Once() + + store.On("GetAndLock", "apigatewayListAllDomainNames").Return(nil).Times(1) + store.On("Unlock", "apigatewayListAllDomainNames").Times(1) + store.On("Put", "apigatewayListAllDomainNames", domainNames).Return(false).Times(1) + }, + want: domainNames, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("GetAndLock", "apigatewayListAllDomainNames").Return(domainNames).Times(1) + store.On("Unlock", "apigatewayListAllDomainNames").Times(1) + }, + want: domainNames, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllDomainNames() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllVpcLinks(t *testing.T) { + vpcLinks := []*apigateway.UpdateVpcLinkOutput{ + {Id: aws.String("vpcLink1")}, + {Id: aws.String("vpcLink2")}, + {Id: aws.String("vpcLink3")}, + {Id: aws.String("vpcLink4")}, + {Id: aws.String("vpcLink5")}, + {Id: aws.String("vpcLink6")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.UpdateVpcLinkOutput + wantErr error + }{ + { + name: "list multiple vpc links", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetVpcLinksPages", + &apigateway.GetVpcLinksInput{}, + mock.MatchedBy(func(callback func(res *apigateway.GetVpcLinksOutput, lastPage bool) bool) bool { + callback(&apigateway.GetVpcLinksOutput{ + Items: vpcLinks[:3], + }, false) + callback(&apigateway.GetVpcLinksOutput{ + Items: vpcLinks[3:], + }, true) + return true + })).Return(nil).Once() + + store.On("Get", "apigatewayListAllVpcLinks").Return(nil).Times(1) + store.On("Put", "apigatewayListAllVpcLinks", vpcLinks).Return(false).Times(1) + }, + want: vpcLinks, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllVpcLinks").Return(vpcLinks).Times(1) + }, + want: vpcLinks, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllVpcLinks() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllRestApiRequestValidators(t *testing.T) { + api := &apigateway.RestApi{ + Id: aws.String("restapi1"), + } + + requestValidators := []*apigateway.UpdateRequestValidatorOutput{ + {Id: aws.String("reqVal1")}, + {Id: aws.String("reqVal2")}, + {Id: aws.String("reqVal3")}, + {Id: aws.String("reqVal4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.UpdateRequestValidatorOutput + wantErr error + }{ + { + name: "list multiple rest api request validators", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetRequestValidators", + &apigateway.GetRequestValidatorsInput{ + RestApiId: aws.String("restapi1"), + }).Return(&apigateway.GetRequestValidatorsOutput{Items: requestValidators}, nil).Once() + + store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(nil).Times(1) + store.On("Put", "apigatewayListAllRestApiRequestValidators_api_restapi1", requestValidators).Return(false).Times(1) + }, + want: requestValidators, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(requestValidators).Times(1) + }, + want: requestValidators, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetRequestValidators", + &apigateway.GetRequestValidatorsInput{ + RestApiId: aws.String("restapi1"), + }).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApiRequestValidators(*api.Id) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllDomainNameBasePathMappings(t *testing.T) { + domainName := &apigateway.DomainName{ + DomainName: aws.String("domainName1"), + } + + mappings := []*apigateway.BasePathMapping{ + {BasePath: aws.String("path1")}, + {BasePath: aws.String("path2")}, + {BasePath: aws.String("path3")}, + {BasePath: aws.String("path4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.BasePathMapping + wantErr error + }{ + { + name: "list multiple domain name base path mappings", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetBasePathMappingsPages", + &apigateway.GetBasePathMappingsInput{ + DomainName: aws.String("domainName1"), + }, + mock.MatchedBy(func(callback func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool) bool { + callback(&apigateway.GetBasePathMappingsOutput{ + Items: mappings, + }, true) + return true + })).Return(nil).Once() + + store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(nil).Times(1) + store.On("Put", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1", mappings).Return(false).Times(1) + }, + want: mappings, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(mappings).Times(1) + }, + want: mappings, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetBasePathMappingsPages", + &apigateway.GetBasePathMappingsInput{ + DomainName: aws.String("domainName1"), + }, mock.AnythingOfType("func(*apigateway.GetBasePathMappingsOutput, bool) bool")).Return(remoteError).Once() + + store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllDomainNameBasePathMappings(*domainName.DomainName) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllRestApiModels(t *testing.T) { + api := &apigateway.RestApi{ + Id: aws.String("restapi1"), + } + + apiModels := []*apigateway.Model{ + {Id: aws.String("model1")}, + {Id: aws.String("model2")}, + {Id: aws.String("model3")}, + {Id: aws.String("model4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.Model + wantErr error + }{ + { + name: "list multiple rest api models", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetModelsPages", + &apigateway.GetModelsInput{ + RestApiId: aws.String("restapi1"), + }, + mock.MatchedBy(func(callback func(res *apigateway.GetModelsOutput, lastPage bool) bool) bool { + callback(&apigateway.GetModelsOutput{ + Items: apiModels, + }, true) + return true + })).Return(nil).Once() + + store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(nil).Times(1) + store.On("Put", "apigatewayListAllRestApiModels_api_restapi1", apiModels).Return(false).Times(1) + }, + want: apiModels, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(apiModels).Times(1) + }, + want: apiModels, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetModelsPages", + &apigateway.GetModelsInput{ + RestApiId: aws.String("restapi1"), + }, mock.AnythingOfType("func(*apigateway.GetModelsOutput, bool) bool")).Return(remoteError).Once() + + store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApiModels(*api.Id) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayRepository_ListAllRestApiGatewayResponses(t *testing.T) { + api := &apigateway.RestApi{ + Id: aws.String("restapi1"), + } + + gtwResponses := []*apigateway.UpdateGatewayResponseOutput{ + {ResponseType: aws.String("ACCESS_DENIED")}, + {ResponseType: aws.String("DEFAULT_4XX")}, + {ResponseType: aws.String("DEFAULT_5XX")}, + {ResponseType: aws.String("UNAUTHORIZED")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) + want []*apigateway.UpdateGatewayResponseOutput + wantErr error + }{ + { + name: "list multiple rest api gateway responses", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetGatewayResponses", + &apigateway.GetGatewayResponsesInput{ + RestApiId: aws.String("restapi1"), + }).Return(&apigateway.GetGatewayResponsesOutput{Items: gtwResponses}, nil).Once() + + store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(nil).Times(1) + store.On("Put", "apigatewayListAllRestApiGatewayResponses_api_restapi1", gtwResponses).Return(false).Times(1) + }, + want: gtwResponses, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(gtwResponses).Times(1) + }, + want: gtwResponses, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { + client.On("GetGatewayResponses", + &apigateway.GetGatewayResponsesInput{ + RestApiId: aws.String("restapi1"), + }).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGateway{} + tt.mocks(client, store) + r := &apigatewayRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRestApiGatewayResponses(*api.Id) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/apigatewayv2_repository.go b/enumeration/remote/aws/repository/apigatewayv2_repository.go new file mode 100644 index 00000000..cb28704d --- /dev/null +++ b/enumeration/remote/aws/repository/apigatewayv2_repository.go @@ -0,0 +1,228 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/apigatewayv2" + "github.com/aws/aws-sdk-go/service/apigatewayv2/apigatewayv2iface" +) + +type ApiGatewayV2Repository interface { + ListAllApis() ([]*apigatewayv2.Api, error) + ListAllApiRoutes(apiId *string) ([]*apigatewayv2.Route, error) + ListAllApiDeployments(apiId *string) ([]*apigatewayv2.Deployment, error) + ListAllVpcLinks() ([]*apigatewayv2.VpcLink, error) + ListAllApiAuthorizers(string) ([]*apigatewayv2.Authorizer, error) + ListAllApiIntegrations(string) ([]*apigatewayv2.Integration, error) + ListAllApiModels(string) ([]*apigatewayv2.Model, error) + ListAllApiStages(string) ([]*apigatewayv2.Stage, error) + ListAllApiRouteResponses(string, string) ([]*apigatewayv2.RouteResponse, error) + ListAllApiMappings(string) ([]*apigatewayv2.ApiMapping, error) + ListAllApiIntegrationResponses(string, string) ([]*apigatewayv2.IntegrationResponse, error) +} +type apigatewayv2Repository struct { + client apigatewayv2iface.ApiGatewayV2API + cache cache.Cache +} + +func NewApiGatewayV2Repository(session *session.Session, c cache.Cache) *apigatewayv2Repository { + return &apigatewayv2Repository{ + apigatewayv2.New(session), + c, + } +} + +func (r *apigatewayv2Repository) ListAllApis() ([]*apigatewayv2.Api, error) { + cacheKey := "apigatewayv2ListAllApis" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*apigatewayv2.Api), nil + } + + input := apigatewayv2.GetApisInput{} + resources, err := r.client.GetApis(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiRoutes(apiID *string) ([]*apigatewayv2.Route, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiRoutes_api_%s", *apiID) + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*apigatewayv2.Route), nil + } + + resources, err := r.client.GetRoutes(&apigatewayv2.GetRoutesInput{ApiId: apiID}) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiDeployments(apiID *string) ([]*apigatewayv2.Deployment, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiDeployments_api_%s", *apiID) + v := r.cache.Get(cacheKey) + + if v != nil { + return v.([]*apigatewayv2.Deployment), nil + } + + resources, err := r.client.GetDeployments(&apigatewayv2.GetDeploymentsInput{ApiId: apiID}) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllVpcLinks() ([]*apigatewayv2.VpcLink, error) { + if v := r.cache.Get("apigatewayv2ListAllVpcLinks"); v != nil { + return v.([]*apigatewayv2.VpcLink), nil + } + + input := apigatewayv2.GetVpcLinksInput{} + resources, err := r.client.GetVpcLinks(&input) + if err != nil { + return nil, err + } + + r.cache.Put("apigatewayv2ListAllVpcLinks", resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiAuthorizers(apiId string) ([]*apigatewayv2.Authorizer, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiAuthorizers_api_%s", apiId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigatewayv2.Authorizer), nil + } + + input := apigatewayv2.GetAuthorizersInput{ + ApiId: &apiId, + } + resources, err := r.client.GetAuthorizers(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiIntegrations(apiId string) ([]*apigatewayv2.Integration, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiIntegrations_api_%s", apiId) + + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigatewayv2.Integration), nil + } + + input := apigatewayv2.GetIntegrationsInput{ + ApiId: &apiId, + } + resources, err := r.client.GetIntegrations(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiModels(apiId string) ([]*apigatewayv2.Model, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiModels_api_%s", apiId) + + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigatewayv2.Model), nil + } + + input := apigatewayv2.GetModelsInput{ + ApiId: &apiId, + } + resources, err := r.client.GetModels(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiStages(apiId string) ([]*apigatewayv2.Stage, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiStages_api_%s", apiId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigatewayv2.Stage), nil + } + + input := apigatewayv2.GetStagesInput{ + ApiId: &apiId, + } + resources, err := r.client.GetStages(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiIntegrationResponses(apiId, integrationId string) ([]*apigatewayv2.IntegrationResponse, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiIntegrationResponses_api_%s_integration_%s", apiId, integrationId) + v := r.cache.Get(cacheKey) + if v != nil { + return v.([]*apigatewayv2.IntegrationResponse), nil + } + input := apigatewayv2.GetIntegrationResponsesInput{ + ApiId: &apiId, + IntegrationId: &integrationId, + } + resources, err := r.client.GetIntegrationResponses(&input) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiRouteResponses(apiId, routeId string) ([]*apigatewayv2.RouteResponse, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiRouteResponses_api_%s_route_%s", apiId, routeId) + v := r.cache.Get(cacheKey) + if v != nil { + return v.([]*apigatewayv2.RouteResponse), nil + } + input := apigatewayv2.GetRouteResponsesInput{ + ApiId: &apiId, + RouteId: &routeId, + } + resources, err := r.client.GetRouteResponses(&input) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} + +func (r *apigatewayv2Repository) ListAllApiMappings(domainName string) ([]*apigatewayv2.ApiMapping, error) { + cacheKey := fmt.Sprintf("apigatewayv2ListAllApiMappings_api_%s", domainName) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*apigatewayv2.ApiMapping), nil + } + + input := apigatewayv2.GetApiMappingsInput{ + DomainName: &domainName, + } + resources, err := r.client.GetApiMappings(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources.Items) + return resources.Items, nil +} diff --git a/enumeration/remote/aws/repository/apigatewayv2_repository_test.go b/enumeration/remote/aws/repository/apigatewayv2_repository_test.go new file mode 100644 index 00000000..ea76f53e --- /dev/null +++ b/enumeration/remote/aws/repository/apigatewayv2_repository_test.go @@ -0,0 +1,637 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/apigatewayv2" + "github.com/pkg/errors" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_apigatewayv2Repository_ListAllApis(t *testing.T) { + apis := []*apigatewayv2.Api{ + {ApiId: aws.String("api1")}, + {ApiId: aws.String("api2")}, + {ApiId: aws.String("api3")}, + {ApiId: aws.String("api4")}, + {ApiId: aws.String("api5")}, + {ApiId: aws.String("api6")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.Api + wantErr error + }{ + { + name: "list multiple apis", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetApis", + &apigatewayv2.GetApisInput{}).Return(&apigatewayv2.GetApisOutput{Items: apis}, nil).Once() + + store.On("GetAndLock", "apigatewayv2ListAllApis").Return(nil).Times(1) + store.On("Unlock", "apigatewayv2ListAllApis").Times(1) + store.On("Put", "apigatewayv2ListAllApis", apis).Return(false).Times(1) + }, + want: apis, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("GetAndLock", "apigatewayv2ListAllApis").Return(apis).Times(1) + store.On("Unlock", "apigatewayv2ListAllApis").Times(1) + }, + want: apis, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetApis", + &apigatewayv2.GetApisInput{}).Return(nil, remoteError).Once() + + store.On("GetAndLock", "apigatewayv2ListAllApis").Return(nil).Times(1) + store.On("Unlock", "apigatewayv2ListAllApis").Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApis() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllApiRoutes(t *testing.T) { + routes := []*apigatewayv2.Route{ + {RouteId: aws.String("route1")}, + {RouteId: aws.String("route2")}, + {RouteId: aws.String("route3")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.Route + wantErr error + }{ + { + name: "list multiple routes", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetRoutes", + &apigatewayv2.GetRoutesInput{ApiId: aws.String("an-id")}). + Return(&apigatewayv2.GetRoutesOutput{Items: routes}, nil).Once() + + store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(nil).Times(1) + store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1) + store.On("Put", "apigatewayv2ListAllApiRoutes_api_an-id", routes).Return(false).Times(1) + }, + want: routes, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(routes).Times(1) + store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1) + }, + want: routes, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetRoutes", + &apigatewayv2.GetRoutesInput{ApiId: aws.String("an-id")}).Return(nil, remoteError).Once() + + store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(nil).Times(1) + store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApiRoutes(aws.String("an-id")) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllApiDeployments(t *testing.T) { + deployments := []*apigatewayv2.Deployment{ + {DeploymentId: aws.String("id1")}, + {DeploymentId: aws.String("id2")}, + {DeploymentId: aws.String("id3")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.Deployment + wantErr error + }{ + { + name: "list multiple deployments", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetDeployments", + &apigatewayv2.GetDeploymentsInput{ApiId: aws.String("an-id")}). + Return(&apigatewayv2.GetDeploymentsOutput{Items: deployments}, nil).Once() + + store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(nil).Times(1) + store.On("Put", "apigatewayv2ListAllApiDeployments_api_an-id", deployments).Return(false).Times(1) + }, + want: deployments, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(deployments).Times(1) + }, + want: deployments, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetDeployments", + &apigatewayv2.GetDeploymentsInput{ApiId: aws.String("an-id")}).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApiDeployments(aws.String("an-id")) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllVpcLinks(t *testing.T) { + vpcLinks := []*apigatewayv2.VpcLink{ + {VpcLinkId: aws.String("vpcLink1")}, + {VpcLinkId: aws.String("vpcLink2")}, + {VpcLinkId: aws.String("vpcLink3")}, + {VpcLinkId: aws.String("vpcLink4")}, + {VpcLinkId: aws.String("vpcLink5")}, + {VpcLinkId: aws.String("vpcLink6")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.VpcLink + wantErr error + }{ + { + name: "list multiple vpc links", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetVpcLinks", + &apigatewayv2.GetVpcLinksInput{}).Return(&apigatewayv2.GetVpcLinksOutput{Items: vpcLinks}, nil).Once() + + store.On("Get", "apigatewayv2ListAllVpcLinks").Return(nil).Times(1) + store.On("Put", "apigatewayv2ListAllVpcLinks", vpcLinks).Return(false).Times(1) + }, + want: vpcLinks, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("Get", "apigatewayv2ListAllVpcLinks").Return(vpcLinks).Times(1) + }, + want: vpcLinks, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetVpcLinks", + &apigatewayv2.GetVpcLinksInput{}).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayv2ListAllVpcLinks").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllVpcLinks() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllApiAuthorizers(t *testing.T) { + api := &apigatewayv2.Api{ + ApiId: aws.String("api1"), + } + + apiAuthorizers := []*apigatewayv2.Authorizer{ + {AuthorizerId: aws.String("authorizer1")}, + {AuthorizerId: aws.String("authorizer2")}, + {AuthorizerId: aws.String("authorizer3")}, + {AuthorizerId: aws.String("authorizer4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.Authorizer + wantErr error + }{ + { + name: "list multiple api authorizers", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetAuthorizers", + &apigatewayv2.GetAuthorizersInput{ + ApiId: aws.String("api1"), + }).Return(&apigatewayv2.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once() + + store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(nil).Times(1) + store.On("Put", "apigatewayv2ListAllApiAuthorizers_api_api1", apiAuthorizers).Return(false).Times(1) + }, + want: apiAuthorizers, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(apiAuthorizers).Times(1) + }, + want: apiAuthorizers, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetAuthorizers", + &apigatewayv2.GetAuthorizersInput{ + ApiId: aws.String("api1"), + }).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApiAuthorizers(*api.ApiId) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllApiIntegrations(t *testing.T) { + api := &apigatewayv2.Api{ + ApiId: aws.String("api1"), + } + + apiIntegrations := []*apigatewayv2.Integration{ + {IntegrationId: aws.String("integration1")}, + {IntegrationId: aws.String("integration2")}, + {IntegrationId: aws.String("integration3")}, + {IntegrationId: aws.String("integration4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.Integration + wantErr error + }{ + { + name: "list multiple api integrations", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetIntegrations", + &apigatewayv2.GetIntegrationsInput{ + ApiId: aws.String("api1"), + }).Return(&apigatewayv2.GetIntegrationsOutput{Items: apiIntegrations}, nil).Once() + + store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(nil).Times(1) + store.On("Put", "apigatewayv2ListAllApiIntegrations_api_api1", apiIntegrations).Return(false).Times(1) + }, + want: apiIntegrations, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(apiIntegrations).Times(1) + }, + want: apiIntegrations, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetIntegrations", + &apigatewayv2.GetIntegrationsInput{ + ApiId: aws.String("api1"), + }).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApiIntegrations(*api.ApiId) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllApiRouteResponses(t *testing.T) { + api := &apigatewayv2.Api{ + ApiId: aws.String("api1"), + } + + route := &apigatewayv2.Route{ + RouteId: aws.String("route1"), + } + + responses := []*apigatewayv2.RouteResponse{ + {RouteResponseId: aws.String("response1")}, + {RouteResponseId: aws.String("response2")}, + {RouteResponseId: aws.String("response3")}, + {RouteResponseId: aws.String("response4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.RouteResponse + wantErr error + }{ + { + name: "list multiple api route responses", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetRouteResponses", + &apigatewayv2.GetRouteResponsesInput{ + ApiId: aws.String("api1"), + RouteId: aws.String("route1"), + }).Return(&apigatewayv2.GetRouteResponsesOutput{Items: responses}, nil).Once() + + store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(nil).Times(1) + store.On("Put", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1", responses).Return(false).Times(1) + }, + want: responses, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(responses).Times(1) + }, + want: responses, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetRouteResponses", + &apigatewayv2.GetRouteResponsesInput{ + ApiId: aws.String("api1"), + RouteId: aws.String("route1"), + }).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApiRouteResponses(*api.ApiId, *route.RouteId) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} + +func Test_apigatewayv2Repository_ListAllApiIntegrationResponses(t *testing.T) { + api := &apigatewayv2.Api{ + ApiId: aws.String("api1"), + } + + integration := &apigatewayv2.Integration{ + IntegrationId: aws.String("integration1"), + } + + responses := []*apigatewayv2.IntegrationResponse{ + {IntegrationResponseId: aws.String("response1")}, + {IntegrationResponseId: aws.String("response2")}, + {IntegrationResponseId: aws.String("response3")}, + {IntegrationResponseId: aws.String("response4")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) + want []*apigatewayv2.IntegrationResponse + wantErr error + }{ + { + name: "list multiple api integration responses", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetIntegrationResponses", + &apigatewayv2.GetIntegrationResponsesInput{ + ApiId: aws.String("api1"), + IntegrationId: aws.String("integration1"), + }).Return(&apigatewayv2.GetIntegrationResponsesOutput{Items: responses}, nil).Once() + + store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(nil).Times(1) + store.On("Put", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1", responses).Return(false).Times(1) + }, + want: responses, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(responses).Times(1) + }, + want: responses, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { + client.On("GetIntegrationResponses", + &apigatewayv2.GetIntegrationResponsesInput{ + ApiId: aws.String("api1"), + IntegrationId: aws.String("integration1"), + }).Return(nil, remoteError).Once() + + store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApiGatewayV2{} + tt.mocks(client, store) + r := &apigatewayv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllApiIntegrationResponses(*api.ApiId, *integration.IntegrationId) + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/appautoscaling_repository.go b/enumeration/remote/aws/repository/appautoscaling_repository.go new file mode 100644 index 00000000..60155999 --- /dev/null +++ b/enumeration/remote/aws/repository/appautoscaling_repository.go @@ -0,0 +1,87 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/applicationautoscaling" + "github.com/aws/aws-sdk-go/service/applicationautoscaling/applicationautoscalingiface" +) + +type AppAutoScalingRepository interface { + ServiceNamespaceValues() []string + DescribeScalableTargets(string) ([]*applicationautoscaling.ScalableTarget, error) + DescribeScalingPolicies(string) ([]*applicationautoscaling.ScalingPolicy, error) + DescribeScheduledActions(string) ([]*applicationautoscaling.ScheduledAction, error) +} + +type appAutoScalingRepository struct { + client applicationautoscalingiface.ApplicationAutoScalingAPI + cache cache.Cache +} + +func NewAppAutoScalingRepository(session *session.Session, c cache.Cache) *appAutoScalingRepository { + return &appAutoScalingRepository{ + applicationautoscaling.New(session), + c, + } +} + +func (r *appAutoScalingRepository) ServiceNamespaceValues() []string { + return applicationautoscaling.ServiceNamespace_Values() +} + +func (r *appAutoScalingRepository) DescribeScalableTargets(namespace string) ([]*applicationautoscaling.ScalableTarget, error) { + cacheKey := fmt.Sprintf("appAutoScalingDescribeScalableTargets_%s", namespace) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*applicationautoscaling.ScalableTarget), nil + } + + input := &applicationautoscaling.DescribeScalableTargetsInput{ + ServiceNamespace: &namespace, + } + result, err := r.client.DescribeScalableTargets(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, result.ScalableTargets) + return result.ScalableTargets, nil +} + +func (r *appAutoScalingRepository) DescribeScalingPolicies(namespace string) ([]*applicationautoscaling.ScalingPolicy, error) { + cacheKey := fmt.Sprintf("appAutoScalingDescribeScalingPolicies_%s", namespace) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*applicationautoscaling.ScalingPolicy), nil + } + + input := &applicationautoscaling.DescribeScalingPoliciesInput{ + ServiceNamespace: &namespace, + } + result, err := r.client.DescribeScalingPolicies(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, result.ScalingPolicies) + return result.ScalingPolicies, nil +} + +func (r *appAutoScalingRepository) DescribeScheduledActions(namespace string) ([]*applicationautoscaling.ScheduledAction, error) { + cacheKey := fmt.Sprintf("appAutoScalingDescribeScheduledActions_%s", namespace) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*applicationautoscaling.ScheduledAction), nil + } + + input := &applicationautoscaling.DescribeScheduledActionsInput{ + ServiceNamespace: &namespace, + } + result, err := r.client.DescribeScheduledActions(input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, result.ScheduledActions) + return result.ScheduledActions, nil +} diff --git a/enumeration/remote/aws/repository/appautoscaling_repository_test.go b/enumeration/remote/aws/repository/appautoscaling_repository_test.go new file mode 100644 index 00000000..bc41cb04 --- /dev/null +++ b/enumeration/remote/aws/repository/appautoscaling_repository_test.go @@ -0,0 +1,342 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/applicationautoscaling" + "github.com/pkg/errors" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_appautoscalingRepository_DescribeScalableTargets(t *testing.T) { + type args struct { + namespace string + } + + tests := []struct { + name string + args args + mocks func(*awstest.MockFakeApplicationAutoScaling, *cache.MockCache) + want []*applicationautoscaling.ScalableTarget + wantErr error + }{ + { + name: "should return remote error", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + client.On("DescribeScalableTargets", + &applicationautoscaling.DescribeScalableTargetsInput{ + ServiceNamespace: aws.String("test"), + }).Return(nil, errors.New("remote error")).Once() + + c.On("Get", "appAutoScalingDescribeScalableTargets_test").Return(nil).Once() + }, + want: nil, + wantErr: errors.New("remote error"), + }, + { + name: "should return scalable targets", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + results := []*applicationautoscaling.ScalableTarget{ + { + RoleARN: aws.String("test_target"), + }, + } + + client.On("DescribeScalableTargets", + &applicationautoscaling.DescribeScalableTargetsInput{ + ServiceNamespace: aws.String("test"), + }).Return(&applicationautoscaling.DescribeScalableTargetsOutput{ + ScalableTargets: results, + }, nil).Once() + + c.On("Get", "appAutoScalingDescribeScalableTargets_test").Return(nil).Once() + c.On("Put", "appAutoScalingDescribeScalableTargets_test", results).Return(true).Once() + }, + want: []*applicationautoscaling.ScalableTarget{ + { + RoleARN: aws.String("test_target"), + }, + }, + }, + { + name: "should hit cache return scalable targets", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + results := []*applicationautoscaling.ScalableTarget{ + { + RoleARN: aws.String("test_target"), + }, + } + + c.On("Get", "appAutoScalingDescribeScalableTargets_test").Return(results).Once() + }, + want: []*applicationautoscaling.ScalableTarget{ + { + RoleARN: aws.String("test_target"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApplicationAutoScaling{} + tt.mocks(client, store) + + r := &appAutoScalingRepository{ + client: client, + cache: store, + } + got, err := r.DescribeScalableTargets(tt.args.namespace) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + } else { + assert.Equal(t, tt.wantErr, err) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + + client.AssertExpectations(t) + store.AssertExpectations(t) + }) + } +} + +func Test_appautoscalingRepository_DescribeScalingPolicies(t *testing.T) { + type args struct { + namespace string + } + + tests := []struct { + name string + args args + mocks func(*awstest.MockFakeApplicationAutoScaling, *cache.MockCache) + want []*applicationautoscaling.ScalingPolicy + wantErr error + }{ + { + name: "should return remote error", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + client.On("DescribeScalingPolicies", + &applicationautoscaling.DescribeScalingPoliciesInput{ + ServiceNamespace: aws.String("test"), + }).Return(nil, errors.New("remote error")).Once() + + c.On("Get", "appAutoScalingDescribeScalingPolicies_test").Return(nil).Once() + }, + want: nil, + wantErr: errors.New("remote error"), + }, + { + name: "should return scaling policies", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + results := []*applicationautoscaling.ScalingPolicy{ + { + PolicyARN: aws.String("test_policy"), + }, + } + + client.On("DescribeScalingPolicies", + &applicationautoscaling.DescribeScalingPoliciesInput{ + ServiceNamespace: aws.String("test"), + }).Return(&applicationautoscaling.DescribeScalingPoliciesOutput{ + ScalingPolicies: results, + }, nil).Once() + + c.On("Get", "appAutoScalingDescribeScalingPolicies_test").Return(nil).Once() + c.On("Put", "appAutoScalingDescribeScalingPolicies_test", results).Return(true).Once() + }, + want: []*applicationautoscaling.ScalingPolicy{ + { + PolicyARN: aws.String("test_policy"), + }, + }, + }, + { + name: "should hit cache return scaling policies", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + results := []*applicationautoscaling.ScalingPolicy{ + { + PolicyARN: aws.String("test_policy"), + }, + } + + c.On("Get", "appAutoScalingDescribeScalingPolicies_test").Return(results).Once() + }, + want: []*applicationautoscaling.ScalingPolicy{ + { + PolicyARN: aws.String("test_policy"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApplicationAutoScaling{} + tt.mocks(client, store) + + r := &appAutoScalingRepository{ + client: client, + cache: store, + } + got, err := r.DescribeScalingPolicies(tt.args.namespace) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + } else { + assert.Equal(t, tt.wantErr, err) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + + client.AssertExpectations(t) + store.AssertExpectations(t) + }) + } +} + +func Test_appautoscalingRepository_DescribeScheduledActions(t *testing.T) { + type args struct { + namespace string + } + + tests := []struct { + name string + args args + mocks func(*awstest.MockFakeApplicationAutoScaling, *cache.MockCache) + want []*applicationautoscaling.ScheduledAction + wantErr error + }{ + { + name: "should return remote error", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + client.On("DescribeScheduledActions", + &applicationautoscaling.DescribeScheduledActionsInput{ + ServiceNamespace: aws.String("test"), + }).Return(nil, errors.New("remote error")).Once() + + c.On("Get", "appAutoScalingDescribeScheduledActions_test").Return(nil).Once() + }, + want: nil, + wantErr: errors.New("remote error"), + }, + { + name: "should return scheduled actions", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + results := []*applicationautoscaling.ScheduledAction{ + { + ResourceId: aws.String("test"), + }, + } + + client.On("DescribeScheduledActions", + &applicationautoscaling.DescribeScheduledActionsInput{ + ServiceNamespace: aws.String("test"), + }).Return(&applicationautoscaling.DescribeScheduledActionsOutput{ + ScheduledActions: results, + }, nil).Once() + + c.On("Get", "appAutoScalingDescribeScheduledActions_test").Return(nil).Once() + c.On("Put", "appAutoScalingDescribeScheduledActions_test", results).Return(true).Once() + }, + want: []*applicationautoscaling.ScheduledAction{ + { + ResourceId: aws.String("test"), + }, + }, + }, + { + name: "should hit cache return scheduled actions", + args: args{ + namespace: "test", + }, + mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { + results := []*applicationautoscaling.ScheduledAction{ + { + ResourceId: aws.String("test"), + }, + } + + c.On("Get", "appAutoScalingDescribeScheduledActions_test").Return(results).Once() + }, + want: []*applicationautoscaling.ScheduledAction{ + { + ResourceId: aws.String("test"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeApplicationAutoScaling{} + tt.mocks(client, store) + + r := &appAutoScalingRepository{ + client: client, + cache: store, + } + got, err := r.DescribeScheduledActions(tt.args.namespace) + if err != nil { + assert.EqualError(t, tt.wantErr, err.Error()) + } else { + assert.Equal(t, tt.wantErr, err) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + + client.AssertExpectations(t) + store.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/autoscaling_repository.go b/enumeration/remote/aws/repository/autoscaling_repository.go new file mode 100644 index 00000000..b00da44c --- /dev/null +++ b/enumeration/remote/aws/repository/autoscaling_repository.go @@ -0,0 +1,44 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/autoscaling" + "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type AutoScalingRepository interface { + DescribeLaunchConfigurations() ([]*autoscaling.LaunchConfiguration, error) +} + +type autoScalingRepository struct { + client autoscalingiface.AutoScalingAPI + cache cache.Cache +} + +func NewAutoScalingRepository(session *session.Session, c cache.Cache) *autoScalingRepository { + return &autoScalingRepository{ + autoscaling.New(session), + c, + } +} + +func (r *autoScalingRepository) DescribeLaunchConfigurations() ([]*autoscaling.LaunchConfiguration, error) { + cacheKey := "DescribeLaunchConfigurations" + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*autoscaling.LaunchConfiguration), nil + } + + var results []*autoscaling.LaunchConfiguration + input := &autoscaling.DescribeLaunchConfigurationsInput{} + err := r.client.DescribeLaunchConfigurationsPages(input, func(resp *autoscaling.DescribeLaunchConfigurationsOutput, lastPage bool) bool { + results = append(results, resp.LaunchConfigurations...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, results) + return results, nil +} diff --git a/enumeration/remote/aws/repository/autoscaling_repository_test.go b/enumeration/remote/aws/repository/autoscaling_repository_test.go new file mode 100644 index 00000000..fad155b7 --- /dev/null +++ b/enumeration/remote/aws/repository/autoscaling_repository_test.go @@ -0,0 +1,104 @@ +package repository + +import ( + "errors" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/service/autoscaling" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-sdk-go/aws" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_AutoscalingRepository_DescribeLaunchConfigurations(t *testing.T) { + dummryError := errors.New("dummy error") + + expectedLaunchConfigurations := []*autoscaling.LaunchConfiguration{ + {ImageId: aws.String("1")}, + {ImageId: aws.String("2")}, + {ImageId: aws.String("3")}, + {ImageId: aws.String("4")}, + } + + tests := []struct { + name string + mocks func(*awstest.MockFakeAutoscaling, *cache.MockCache) + want []*autoscaling.LaunchConfiguration + wantErr error + }{ + { + name: "List all launch configurations", + mocks: func(client *awstest.MockFakeAutoscaling, store *cache.MockCache) { + store.On("Get", "DescribeLaunchConfigurations").Return(nil).Once() + + client.On("DescribeLaunchConfigurationsPages", + &autoscaling.DescribeLaunchConfigurationsInput{}, + mock.MatchedBy(func(callback func(res *autoscaling.DescribeLaunchConfigurationsOutput, lastPage bool) bool) bool { + callback(&autoscaling.DescribeLaunchConfigurationsOutput{ + LaunchConfigurations: expectedLaunchConfigurations[:2], + }, false) + callback(&autoscaling.DescribeLaunchConfigurationsOutput{ + LaunchConfigurations: expectedLaunchConfigurations[2:], + }, true) + return true + })).Return(nil).Once() + + store.On("Put", "DescribeLaunchConfigurations", expectedLaunchConfigurations).Return(false).Once() + }, + want: expectedLaunchConfigurations, + }, + { + name: "Hit cache and list all launch configurations", + mocks: func(client *awstest.MockFakeAutoscaling, store *cache.MockCache) { + store.On("Get", "DescribeLaunchConfigurations").Return(expectedLaunchConfigurations).Once() + }, + want: expectedLaunchConfigurations, + }, + { + name: "Error listing all launch configurations", + mocks: func(client *awstest.MockFakeAutoscaling, store *cache.MockCache) { + store.On("Get", "DescribeLaunchConfigurations").Return(nil).Once() + + client.On("DescribeLaunchConfigurationsPages", &autoscaling.DescribeLaunchConfigurationsInput{}, mock.MatchedBy(func(callback func(res *autoscaling.DescribeLaunchConfigurationsOutput, lastPage bool) bool) bool { + callback(&autoscaling.DescribeLaunchConfigurationsOutput{ + LaunchConfigurations: []*autoscaling.LaunchConfiguration{}, + }, true) + return true + })).Return(dummryError).Once() + }, + want: nil, + wantErr: dummryError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeAutoscaling{} + tt.mocks(client, store) + r := &autoScalingRepository{ + client: client, + cache: store, + } + got, err := r.DescribeLaunchConfigurations() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/cloudformation_repository.go b/enumeration/remote/aws/repository/cloudformation_repository.go new file mode 100644 index 00000000..e0aa1dd0 --- /dev/null +++ b/enumeration/remote/aws/repository/cloudformation_repository.go @@ -0,0 +1,47 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudformation" + "github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type CloudformationRepository interface { + ListAllStacks() ([]*cloudformation.Stack, error) +} + +type cloudformationRepository struct { + client cloudformationiface.CloudFormationAPI + cache cache.Cache +} + +func NewCloudformationRepository(session *session.Session, c cache.Cache) *cloudformationRepository { + return &cloudformationRepository{ + cloudformation.New(session), + c, + } +} + +func (r *cloudformationRepository) ListAllStacks() ([]*cloudformation.Stack, error) { + if v := r.cache.Get("cloudformationListAllStacks"); v != nil { + return v.([]*cloudformation.Stack), nil + } + + var stacks []*cloudformation.Stack + input := cloudformation.DescribeStacksInput{} + err := r.client.DescribeStacksPages(&input, + func(resp *cloudformation.DescribeStacksOutput, lastPage bool) bool { + if resp.Stacks != nil { + stacks = append(stacks, resp.Stacks...) + } + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put("cloudformationListAllStacks", stacks) + return stacks, nil +} diff --git a/enumeration/remote/aws/repository/cloudformation_repository_test.go b/enumeration/remote/aws/repository/cloudformation_repository_test.go new file mode 100644 index 00000000..3f2dd747 --- /dev/null +++ b/enumeration/remote/aws/repository/cloudformation_repository_test.go @@ -0,0 +1,86 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/cloudformation" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_cloudformationRepository_ListAllStacks(t *testing.T) { + stacks := []*cloudformation.Stack{ + {StackId: aws.String("stack1")}, + {StackId: aws.String("stack2")}, + {StackId: aws.String("stack3")}, + {StackId: aws.String("stack4")}, + {StackId: aws.String("stack5")}, + {StackId: aws.String("stack6")}, + } + + tests := []struct { + name string + mocks func(client *awstest.MockFakeCloudformation, store *cache.MockCache) + want []*cloudformation.Stack + wantErr error + }{ + { + name: "list multiple stacks", + mocks: func(client *awstest.MockFakeCloudformation, store *cache.MockCache) { + client.On("DescribeStacksPages", + &cloudformation.DescribeStacksInput{}, + mock.MatchedBy(func(callback func(res *cloudformation.DescribeStacksOutput, lastPage bool) bool) bool { + callback(&cloudformation.DescribeStacksOutput{ + Stacks: stacks[:3], + }, false) + callback(&cloudformation.DescribeStacksOutput{ + Stacks: stacks[3:], + }, true) + return true + })).Return(nil).Once() + + store.On("Get", "cloudformationListAllStacks").Return(nil).Times(1) + store.On("Put", "cloudformationListAllStacks", stacks).Return(false).Times(1) + }, + want: stacks, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeCloudformation, store *cache.MockCache) { + store.On("Get", "cloudformationListAllStacks").Return(stacks).Times(1) + }, + want: stacks, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeCloudformation{} + tt.mocks(client, store) + r := &cloudformationRepository{ + client: client, + cache: store, + } + got, err := r.ListAllStacks() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/cloudfront_repository.go b/enumeration/remote/aws/repository/cloudfront_repository.go new file mode 100644 index 00000000..bb7d0ec7 --- /dev/null +++ b/enumeration/remote/aws/repository/cloudfront_repository.go @@ -0,0 +1,47 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudfront" + "github.com/aws/aws-sdk-go/service/cloudfront/cloudfrontiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type CloudfrontRepository interface { + ListAllDistributions() ([]*cloudfront.DistributionSummary, error) +} + +type cloudfrontRepository struct { + client cloudfrontiface.CloudFrontAPI + cache cache.Cache +} + +func NewCloudfrontRepository(session *session.Session, c cache.Cache) *cloudfrontRepository { + return &cloudfrontRepository{ + cloudfront.New(session), + c, + } +} + +func (r *cloudfrontRepository) ListAllDistributions() ([]*cloudfront.DistributionSummary, error) { + if v := r.cache.Get("cloudfrontListAllDistributions"); v != nil { + return v.([]*cloudfront.DistributionSummary), nil + } + + var distributions []*cloudfront.DistributionSummary + input := cloudfront.ListDistributionsInput{} + err := r.client.ListDistributionsPages(&input, + func(resp *cloudfront.ListDistributionsOutput, lastPage bool) bool { + if resp.DistributionList != nil { + distributions = append(distributions, resp.DistributionList.Items...) + } + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put("cloudfrontListAllDistributions", distributions) + return distributions, nil +} diff --git a/enumeration/remote/aws/repository/cloudfront_repository_test.go b/enumeration/remote/aws/repository/cloudfront_repository_test.go new file mode 100644 index 00000000..4c39d437 --- /dev/null +++ b/enumeration/remote/aws/repository/cloudfront_repository_test.go @@ -0,0 +1,92 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/cloudfront" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_cloudfrontRepository_ListAllDistributions(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeCloudFront) + want []*cloudfront.DistributionSummary + wantErr error + }{ + { + name: "list multiple distributions", + mocks: func(client *awstest.MockFakeCloudFront) { + client.On("ListDistributionsPages", + &cloudfront.ListDistributionsInput{}, + mock.MatchedBy(func(callback func(res *cloudfront.ListDistributionsOutput, lastPage bool) bool) bool { + callback(&cloudfront.ListDistributionsOutput{ + DistributionList: &cloudfront.DistributionList{ + Items: []*cloudfront.DistributionSummary{ + {Id: aws.String("distribution1")}, + {Id: aws.String("distribution2")}, + {Id: aws.String("distribution3")}, + }, + }, + }, false) + callback(&cloudfront.ListDistributionsOutput{ + DistributionList: &cloudfront.DistributionList{ + Items: []*cloudfront.DistributionSummary{ + {Id: aws.String("distribution4")}, + {Id: aws.String("distribution5")}, + {Id: aws.String("distribution6")}, + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*cloudfront.DistributionSummary{ + {Id: aws.String("distribution1")}, + {Id: aws.String("distribution2")}, + {Id: aws.String("distribution3")}, + {Id: aws.String("distribution4")}, + {Id: aws.String("distribution5")}, + {Id: aws.String("distribution6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeCloudFront{} + tt.mocks(&client) + r := &cloudfrontRepository{ + client: &client, + cache: store, + } + got, err := r.ListAllDistributions() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllDistributions() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*cloudfront.DistributionSummary{}, store.Get("cloudfrontListAllDistributions")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/dynamodb_repository.go b/enumeration/remote/aws/repository/dynamodb_repository.go new file mode 100644 index 00000000..42efe29a --- /dev/null +++ b/enumeration/remote/aws/repository/dynamodb_repository.go @@ -0,0 +1,43 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type DynamoDBRepository interface { + ListAllTables() ([]*string, error) +} + +type dynamoDBRepository struct { + client dynamodbiface.DynamoDBAPI + cache cache.Cache +} + +func NewDynamoDBRepository(session *session.Session, c cache.Cache) *dynamoDBRepository { + return &dynamoDBRepository{ + dynamodb.New(session), + c, + } +} + +func (r *dynamoDBRepository) ListAllTables() ([]*string, error) { + if v := r.cache.Get("dynamodbListAllTables"); v != nil { + return v.([]*string), nil + } + + var tables []*string + input := &dynamodb.ListTablesInput{} + err := r.client.ListTablesPages(input, func(res *dynamodb.ListTablesOutput, lastPage bool) bool { + tables = append(tables, res.TableNames...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("dynamodbListAllTables", tables) + return tables, nil +} diff --git a/enumeration/remote/aws/repository/dynamodb_repository_test.go b/enumeration/remote/aws/repository/dynamodb_repository_test.go new file mode 100644 index 00000000..3b89046d --- /dev/null +++ b/enumeration/remote/aws/repository/dynamodb_repository_test.go @@ -0,0 +1,89 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/dynamodb" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_dynamoDBRepository_ListAllTopics(t *testing.T) { + + tests := []struct { + name string + mocks func(client *awstest.MockFakeDynamoDB) + want []*string + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeDynamoDB) { + client.On("ListTablesPages", + &dynamodb.ListTablesInput{}, + mock.MatchedBy(func(callback func(res *dynamodb.ListTablesOutput, lastPage bool) bool) bool { + callback(&dynamodb.ListTablesOutput{ + TableNames: []*string{ + aws.String("1"), + aws.String("2"), + aws.String("3"), + }, + }, false) + callback(&dynamodb.ListTablesOutput{ + TableNames: []*string{ + aws.String("4"), + aws.String("5"), + aws.String("6"), + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*string{ + aws.String("1"), + aws.String("2"), + aws.String("3"), + aws.String("4"), + aws.String("5"), + aws.String("6"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeDynamoDB{} + tt.mocks(&client) + r := &dynamoDBRepository{ + client: &client, + cache: store, + } + got, err := r.ListAllTables() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllTables() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*string{}, store.Get("dynamodbListAllTables")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/ec2_repository.go b/enumeration/remote/aws/repository/ec2_repository.go new file mode 100644 index 00000000..725c11e7 --- /dev/null +++ b/enumeration/remote/aws/repository/ec2_repository.go @@ -0,0 +1,408 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type EC2Repository interface { + ListAllImages() ([]*ec2.Image, error) + ListAllSnapshots() ([]*ec2.Snapshot, error) + ListAllVolumes() ([]*ec2.Volume, error) + ListAllAddresses() ([]*ec2.Address, error) + ListAllAddressesAssociation() ([]*ec2.Address, error) + ListAllInstances() ([]*ec2.Instance, error) + ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) + ListAllInternetGateways() ([]*ec2.InternetGateway, error) + ListAllSubnets() ([]*ec2.Subnet, []*ec2.Subnet, error) + ListAllNatGateways() ([]*ec2.NatGateway, error) + ListAllRouteTables() ([]*ec2.RouteTable, error) + ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) + ListAllSecurityGroups() ([]*ec2.SecurityGroup, []*ec2.SecurityGroup, error) + ListAllNetworkACLs() ([]*ec2.NetworkAcl, error) + DescribeLaunchTemplates() ([]*ec2.LaunchTemplate, error) + IsEbsEncryptionEnabledByDefault() (bool, error) +} + +type ec2Repository struct { + client ec2iface.EC2API + cache cache.Cache +} + +func NewEC2Repository(session *session.Session, c cache.Cache) *ec2Repository { + return &ec2Repository{ + ec2.New(session), + c, + } +} + +func (r *ec2Repository) ListAllImages() ([]*ec2.Image, error) { + if v := r.cache.Get("ec2ListAllImages"); v != nil { + return v.([]*ec2.Image), nil + } + + input := &ec2.DescribeImagesInput{ + Owners: []*string{ + aws.String("self"), + }, + } + images, err := r.client.DescribeImages(input) + if err != nil { + return nil, err + } + r.cache.Put("ec2ListAllImages", images.Images) + return images.Images, err +} + +func (r *ec2Repository) ListAllSnapshots() ([]*ec2.Snapshot, error) { + if v := r.cache.Get("ec2ListAllSnapshots"); v != nil { + return v.([]*ec2.Snapshot), nil + } + + var snapshots []*ec2.Snapshot + input := &ec2.DescribeSnapshotsInput{ + OwnerIds: []*string{ + aws.String("self"), + }, + } + err := r.client.DescribeSnapshotsPages(input, func(res *ec2.DescribeSnapshotsOutput, lastPage bool) bool { + snapshots = append(snapshots, res.Snapshots...) + return !lastPage + }) + if err != nil { + return nil, err + } + r.cache.Put("ec2ListAllSnapshots", snapshots) + return snapshots, err +} + +func (r *ec2Repository) ListAllVolumes() ([]*ec2.Volume, error) { + if v := r.cache.Get("ec2ListAllVolumes"); v != nil { + return v.([]*ec2.Volume), nil + } + + var volumes []*ec2.Volume + input := &ec2.DescribeVolumesInput{} + err := r.client.DescribeVolumesPages(input, func(res *ec2.DescribeVolumesOutput, lastPage bool) bool { + volumes = append(volumes, res.Volumes...) + return !lastPage + }) + if err != nil { + return nil, err + } + r.cache.Put("ec2ListAllVolumes", volumes) + return volumes, nil +} + +func (r *ec2Repository) ListAllAddresses() ([]*ec2.Address, error) { + cacheKey := "ec2ListAllAddresses" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*ec2.Address), nil + } + + input := &ec2.DescribeAddressesInput{} + response, err := r.client.DescribeAddresses(input) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, response.Addresses) + return response.Addresses, nil +} + +func (r *ec2Repository) ListAllAddressesAssociation() ([]*ec2.Address, error) { + if v := r.cache.Get("ec2ListAllAddressesAssociation"); v != nil { + return v.([]*ec2.Address), nil + } + + addresses, err := r.ListAllAddresses() + if err != nil { + return nil, err + } + results := make([]*ec2.Address, 0, len(addresses)) + + for _, address := range addresses { + if address.AssociationId != nil { + results = append(results, address) + } + } + r.cache.Put("ec2ListAllAddressesAssociation", results) + return results, nil +} + +func (r *ec2Repository) ListAllInstances() ([]*ec2.Instance, error) { + if v := r.cache.Get("ec2ListAllInstances"); v != nil { + return v.([]*ec2.Instance), nil + } + var instances []*ec2.Instance + input := &ec2.DescribeInstancesInput{ + Filters: []*ec2.Filter{ + { + // Ignore terminated state from enumeration since terminated means that instance + // has been removed + Name: aws.String("instance-state-name"), + Values: aws.StringSlice([]string{ + "pending", + "running", + "stopping", + "shutting-down", + "stopped", + }), + }, + }, + } + err := r.client.DescribeInstancesPages(input, func(res *ec2.DescribeInstancesOutput, lastPage bool) bool { + for _, reservation := range res.Reservations { + instances = append(instances, reservation.Instances...) + } + return !lastPage + }) + if err != nil { + return nil, err + } + r.cache.Put("ec2ListAllInstances", instances) + return instances, nil +} + +func (r *ec2Repository) ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) { + if v := r.cache.Get("ec2ListAllKeyPairs"); v != nil { + return v.([]*ec2.KeyPairInfo), nil + } + + input := &ec2.DescribeKeyPairsInput{} + pairs, err := r.client.DescribeKeyPairs(input) + if err != nil { + return nil, err + } + r.cache.Put("ec2ListAllKeyPairs", pairs.KeyPairs) + return pairs.KeyPairs, err +} + +func (r *ec2Repository) ListAllInternetGateways() ([]*ec2.InternetGateway, error) { + if v := r.cache.Get("ec2ListAllInternetGateways"); v != nil { + return v.([]*ec2.InternetGateway), nil + } + + var internetGateways []*ec2.InternetGateway + input := ec2.DescribeInternetGatewaysInput{} + err := r.client.DescribeInternetGatewaysPages(&input, + func(resp *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool { + internetGateways = append(internetGateways, resp.InternetGateways...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + r.cache.Put("ec2ListAllInternetGateways", internetGateways) + return internetGateways, nil +} + +func (r *ec2Repository) ListAllSubnets() ([]*ec2.Subnet, []*ec2.Subnet, error) { + cacheKey := "ec2ListAllSubnets" + cacheSubnets := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + + defaultCacheKey := "ec2ListAllDefaultSubnets" + cacheDefaultSubnets := r.cache.GetAndLock(defaultCacheKey) + defer r.cache.Unlock(defaultCacheKey) + if cacheSubnets != nil && cacheDefaultSubnets != nil { + return cacheSubnets.([]*ec2.Subnet), cacheDefaultSubnets.([]*ec2.Subnet), nil + } + + input := ec2.DescribeSubnetsInput{} + var subnets []*ec2.Subnet + var defaultSubnets []*ec2.Subnet + err := r.client.DescribeSubnetsPages(&input, + func(resp *ec2.DescribeSubnetsOutput, lastPage bool) bool { + for _, subnet := range resp.Subnets { + if subnet.DefaultForAz != nil && *subnet.DefaultForAz { + defaultSubnets = append(defaultSubnets, subnet) + continue + } + subnets = append(subnets, subnet) + } + return !lastPage + }) + if err != nil { + return nil, nil, err + } + r.cache.Put(cacheKey, subnets) + r.cache.Put(defaultCacheKey, defaultSubnets) + return subnets, defaultSubnets, nil +} + +func (r *ec2Repository) ListAllNatGateways() ([]*ec2.NatGateway, error) { + if v := r.cache.Get("ec2ListAllNatGateways"); v != nil { + return v.([]*ec2.NatGateway), nil + } + + var result []*ec2.NatGateway + input := ec2.DescribeNatGatewaysInput{} + err := r.client.DescribeNatGatewaysPages(&input, + func(resp *ec2.DescribeNatGatewaysOutput, lastPage bool) bool { + result = append(result, resp.NatGateways...) + return !lastPage + }, + ) + + if err != nil { + return nil, err + } + + r.cache.Put("ec2ListAllNatGateways", result) + return result, nil +} + +func (r *ec2Repository) ListAllRouteTables() ([]*ec2.RouteTable, error) { + cacheKey := "ec2ListAllRouteTables" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*ec2.RouteTable), nil + } + + var routeTables []*ec2.RouteTable + input := ec2.DescribeRouteTablesInput{} + err := r.client.DescribeRouteTablesPages(&input, + func(resp *ec2.DescribeRouteTablesOutput, lastPage bool) bool { + routeTables = append(routeTables, resp.RouteTables...) + return !lastPage + }, + ) + + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, routeTables) + return routeTables, nil +} + +func (r *ec2Repository) ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) { + cacheKey := "ec2ListAllVPCs" + cacheVPCs := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + defaultCacheKey := "ec2ListAllDefaultVPCs" + cacheDefaultVPCs := r.cache.GetAndLock(defaultCacheKey) + defer r.cache.Unlock(defaultCacheKey) + if cacheVPCs != nil && cacheDefaultVPCs != nil { + return cacheVPCs.([]*ec2.Vpc), cacheDefaultVPCs.([]*ec2.Vpc), nil + } + + input := ec2.DescribeVpcsInput{} + var VPCs []*ec2.Vpc + var defaultVPCs []*ec2.Vpc + err := r.client.DescribeVpcsPages(&input, + func(resp *ec2.DescribeVpcsOutput, lastPage bool) bool { + for _, vpc := range resp.Vpcs { + if vpc.IsDefault != nil && *vpc.IsDefault { + defaultVPCs = append(defaultVPCs, vpc) + continue + } + VPCs = append(VPCs, vpc) + } + return !lastPage + }, + ) + if err != nil { + return nil, nil, err + } + + r.cache.Put(cacheKey, VPCs) + r.cache.Put(defaultCacheKey, defaultVPCs) + return VPCs, defaultVPCs, nil +} + +func (r *ec2Repository) ListAllSecurityGroups() ([]*ec2.SecurityGroup, []*ec2.SecurityGroup, error) { + cacheKey := "ec2ListAllSecurityGroups" + cacheSecurityGroups := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + defaultCacheKey := "ec2ListAllDefaultSecurityGroups" + cacheDefaultSecurityGroups := r.cache.GetAndLock(defaultCacheKey) + defer r.cache.Unlock(defaultCacheKey) + if cacheSecurityGroups != nil && cacheDefaultSecurityGroups != nil { + return cacheSecurityGroups.([]*ec2.SecurityGroup), cacheDefaultSecurityGroups.([]*ec2.SecurityGroup), nil + } + + var securityGroups []*ec2.SecurityGroup + var defaultSecurityGroups []*ec2.SecurityGroup + input := &ec2.DescribeSecurityGroupsInput{} + err := r.client.DescribeSecurityGroupsPages(input, func(res *ec2.DescribeSecurityGroupsOutput, lastPage bool) bool { + for _, securityGroup := range res.SecurityGroups { + if securityGroup.GroupName != nil && *securityGroup.GroupName == "default" { + defaultSecurityGroups = append(defaultSecurityGroups, securityGroup) + continue + } + securityGroups = append(securityGroups, securityGroup) + } + return !lastPage + }) + if err != nil { + return nil, nil, err + } + + r.cache.Put(cacheKey, securityGroups) + r.cache.Put(defaultCacheKey, defaultSecurityGroups) + return securityGroups, defaultSecurityGroups, nil +} + +func (r *ec2Repository) ListAllNetworkACLs() ([]*ec2.NetworkAcl, error) { + + cacheKey := "ec2ListAllNetworkACLs" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*ec2.NetworkAcl), nil + } + + var ACLs []*ec2.NetworkAcl + input := ec2.DescribeNetworkAclsInput{} + err := r.client.DescribeNetworkAclsPages(&input, + func(resp *ec2.DescribeNetworkAclsOutput, lastPage bool) bool { + ACLs = append(ACLs, resp.NetworkAcls...) + return !lastPage + }, + ) + + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, ACLs) + return ACLs, nil +} + +func (r *ec2Repository) DescribeLaunchTemplates() ([]*ec2.LaunchTemplate, error) { + cacheKey := "DescribeLaunchTemplates" + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*ec2.LaunchTemplate), nil + } + + input := ec2.DescribeLaunchTemplatesInput{} + resp, err := r.client.DescribeLaunchTemplates(&input) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resp.LaunchTemplates) + return resp.LaunchTemplates, nil +} + +func (r *ec2Repository) IsEbsEncryptionEnabledByDefault() (bool, error) { + if v := r.cache.Get("ec2IsEbsEncryptionEnabledByDefault"); v != nil { + return v.(bool), nil + } + + input := &ec2.GetEbsEncryptionByDefaultInput{} + resp, err := r.client.GetEbsEncryptionByDefault(input) + if err != nil { + return false, err + } + r.cache.Put("ec2IsEbsEncryptionEnabledByDefault", *resp.EbsEncryptionByDefault) + return *resp.EbsEncryptionByDefault, err +} diff --git a/enumeration/remote/aws/repository/ec2_repository_test.go b/enumeration/remote/aws/repository/ec2_repository_test.go new file mode 100644 index 00000000..273a811b --- /dev/null +++ b/enumeration/remote/aws/repository/ec2_repository_test.go @@ -0,0 +1,1429 @@ +package repository + +import ( + cache2 "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/pkg/errors" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-sdk-go/service/ec2" + + "github.com/aws/aws-sdk-go/aws" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_ec2Repository_ListAllImages(t *testing.T) { + + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.Image + wantErr error + }{ + { + name: "List all images", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeImages", + &ec2.DescribeImagesInput{ + Owners: []*string{ + aws.String("self"), + }, + }).Return(&ec2.DescribeImagesOutput{ + Images: []*ec2.Image{ + {ImageId: aws.String("1")}, + {ImageId: aws.String("2")}, + {ImageId: aws.String("3")}, + {ImageId: aws.String("4")}, + }, + }, nil).Once() + }, + want: []*ec2.Image{ + {ImageId: aws.String("1")}, + {ImageId: aws.String("2")}, + {ImageId: aws.String("3")}, + {ImageId: aws.String("4")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllImages() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllImages() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Image{}, store.Get("ec2ListAllImages")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllSnapshots(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.Snapshot + wantErr error + }{ + {name: "List with 2 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeSnapshotsPages", + &ec2.DescribeSnapshotsInput{ + OwnerIds: []*string{ + aws.String("self"), + }, + }, + mock.MatchedBy(func(callback func(res *ec2.DescribeSnapshotsOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeSnapshotsOutput{ + Snapshots: []*ec2.Snapshot{ + {VolumeId: aws.String("1")}, + {VolumeId: aws.String("2")}, + {VolumeId: aws.String("3")}, + {VolumeId: aws.String("4")}, + }, + }, false) + callback(&ec2.DescribeSnapshotsOutput{ + Snapshots: []*ec2.Snapshot{ + {VolumeId: aws.String("5")}, + {VolumeId: aws.String("6")}, + {VolumeId: aws.String("7")}, + {VolumeId: aws.String("8")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.Snapshot{ + {VolumeId: aws.String("1")}, + {VolumeId: aws.String("2")}, + {VolumeId: aws.String("3")}, + {VolumeId: aws.String("4")}, + {VolumeId: aws.String("5")}, + {VolumeId: aws.String("6")}, + {VolumeId: aws.String("7")}, + {VolumeId: aws.String("8")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllSnapshots() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllSnapshots() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Snapshot{}, store.Get("ec2ListAllSnapshots")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllVolumes(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.Volume + wantErr error + }{ + {name: "List with 2 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeVolumesPages", + &ec2.DescribeVolumesInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeVolumesOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeVolumesOutput{ + Volumes: []*ec2.Volume{ + {VolumeId: aws.String("1")}, + {VolumeId: aws.String("2")}, + {VolumeId: aws.String("3")}, + {VolumeId: aws.String("4")}, + }, + }, false) + callback(&ec2.DescribeVolumesOutput{ + Volumes: []*ec2.Volume{ + {VolumeId: aws.String("5")}, + {VolumeId: aws.String("6")}, + {VolumeId: aws.String("7")}, + {VolumeId: aws.String("8")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.Volume{ + {VolumeId: aws.String("1")}, + {VolumeId: aws.String("2")}, + {VolumeId: aws.String("3")}, + {VolumeId: aws.String("4")}, + {VolumeId: aws.String("5")}, + {VolumeId: aws.String("6")}, + {VolumeId: aws.String("7")}, + {VolumeId: aws.String("8")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllVolumes() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllVolumes() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Volume{}, store.Get("ec2ListAllVolumes")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllAddresses(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.Address + wantErr error + }{ + { + name: "List address", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}). + Return(&ec2.DescribeAddressesOutput{ + Addresses: []*ec2.Address{ + {AssociationId: aws.String("1")}, + {AssociationId: aws.String("2")}, + {AssociationId: aws.String("3")}, + {AssociationId: aws.String("4")}, + }, + }, nil).Once() + }, + want: []*ec2.Address{ + {AssociationId: aws.String("1")}, + {AssociationId: aws.String("2")}, + {AssociationId: aws.String("3")}, + {AssociationId: aws.String("4")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllAddresses() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAddresses() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddresses")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.Address + wantErr error + }{ + { + name: "List address", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}). + Return(&ec2.DescribeAddressesOutput{ + Addresses: []*ec2.Address{ + {AssociationId: aws.String("1")}, + {AssociationId: aws.String("2")}, + {AssociationId: aws.String("3")}, + {AssociationId: aws.String("4")}, + }, + }, nil).Once() + }, + want: []*ec2.Address{ + {AssociationId: aws.String("1")}, + {AssociationId: aws.String("2")}, + {AssociationId: aws.String("3")}, + {AssociationId: aws.String("4")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllAddressesAssociation() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAddressesAssociation() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddressesAssociation")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllInstances(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.Instance + wantErr error + }{ + {name: "List with 2 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeInstancesPages", + &ec2.DescribeInstancesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("instance-state-name"), + Values: aws.StringSlice([]string{ + "pending", + "running", + "stopping", + "shutting-down", + "stopped", + }), + }, + }, + }, + mock.MatchedBy(func(callback func(res *ec2.DescribeInstancesOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeInstancesOutput{ + Reservations: []*ec2.Reservation{ + { + Instances: []*ec2.Instance{ + {ImageId: aws.String("1")}, + {ImageId: aws.String("2")}, + {ImageId: aws.String("3")}, + }, + }, + { + Instances: []*ec2.Instance{ + {ImageId: aws.String("4")}, + {ImageId: aws.String("5")}, + {ImageId: aws.String("6")}, + }, + }, + }, + }, false) + callback(&ec2.DescribeInstancesOutput{ + Reservations: []*ec2.Reservation{ + { + Instances: []*ec2.Instance{ + {ImageId: aws.String("7")}, + {ImageId: aws.String("8")}, + {ImageId: aws.String("9")}, + }, + }, + { + Instances: []*ec2.Instance{ + {ImageId: aws.String("10")}, + {ImageId: aws.String("11")}, + {ImageId: aws.String("12")}, + }, + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.Instance{ + {ImageId: aws.String("1")}, + {ImageId: aws.String("2")}, + {ImageId: aws.String("3")}, + {ImageId: aws.String("4")}, + {ImageId: aws.String("5")}, + {ImageId: aws.String("6")}, + {ImageId: aws.String("7")}, + {ImageId: aws.String("8")}, + {ImageId: aws.String("9")}, + {ImageId: aws.String("10")}, + {ImageId: aws.String("11")}, + {ImageId: aws.String("12")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllInstances() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllInstances() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.Instance{}, store.Get("ec2ListAllInstances")) + } + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllKeyPairs(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.KeyPairInfo + wantErr error + }{ + { + name: "List address", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeKeyPairs", &ec2.DescribeKeyPairsInput{}). + Return(&ec2.DescribeKeyPairsOutput{ + KeyPairs: []*ec2.KeyPairInfo{ + {KeyPairId: aws.String("1")}, + {KeyPairId: aws.String("2")}, + {KeyPairId: aws.String("3")}, + {KeyPairId: aws.String("4")}, + }, + }, nil).Once() + }, + want: []*ec2.KeyPairInfo{ + {KeyPairId: aws.String("1")}, + {KeyPairId: aws.String("2")}, + {KeyPairId: aws.String("3")}, + {KeyPairId: aws.String("4")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllKeyPairs() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllKeyPairs() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.KeyPairInfo{}, store.Get("ec2ListAllKeyPairs")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllInternetGateways(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.InternetGateway + wantErr error + }{ + { + name: "List only gateways with multiple pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeInternetGatewaysPages", + &ec2.DescribeInternetGatewaysInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeInternetGatewaysOutput{ + InternetGateways: []*ec2.InternetGateway{ + { + InternetGatewayId: aws.String("Internet-0"), + }, + { + InternetGatewayId: aws.String("Internet-1"), + }, + }, + }, false) + callback(&ec2.DescribeInternetGatewaysOutput{ + InternetGateways: []*ec2.InternetGateway{ + { + InternetGatewayId: aws.String("Internet-2"), + }, + { + InternetGatewayId: aws.String("Internet-3"), + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.InternetGateway{ + { + InternetGatewayId: aws.String("Internet-0"), + }, + { + InternetGatewayId: aws.String("Internet-1"), + }, + { + InternetGatewayId: aws.String("Internet-2"), + }, + { + InternetGatewayId: aws.String("Internet-3"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllInternetGateways() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllInternetGateways() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.InternetGateway{}, store.Get("ec2ListAllInternetGateways")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllSubnets(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + wantSubnet []*ec2.Subnet + wantDefaultSubnet []*ec2.Subnet + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeSubnetsPages", + &ec2.DescribeSubnetsInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeSubnetsOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeSubnetsOutput{ + Subnets: []*ec2.Subnet{ + { + SubnetId: aws.String("subnet-0b13f1e0eacf67424"), // subnet2 + DefaultForAz: aws.Bool(false), + }, + { + SubnetId: aws.String("subnet-0c9b78001fe186e22"), // subnet3 + DefaultForAz: aws.Bool(false), + }, + { + SubnetId: aws.String("subnet-05810d3f933925f6d"), // subnet1 + DefaultForAz: aws.Bool(false), + }, + }, + }, false) + callback(&ec2.DescribeSubnetsOutput{ + Subnets: []*ec2.Subnet{ + { + SubnetId: aws.String("subnet-44fe0c65"), // us-east-1a + DefaultForAz: aws.Bool(true), + }, + { + SubnetId: aws.String("subnet-65e16628"), // us-east-1b + DefaultForAz: aws.Bool(true), + }, + { + SubnetId: aws.String("subnet-afa656f0"), // us-east-1c + DefaultForAz: aws.Bool(true), + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + wantSubnet: []*ec2.Subnet{ + { + SubnetId: aws.String("subnet-0b13f1e0eacf67424"), // subnet2 + DefaultForAz: aws.Bool(false), + }, + { + SubnetId: aws.String("subnet-0c9b78001fe186e22"), // subnet3 + DefaultForAz: aws.Bool(false), + }, + { + SubnetId: aws.String("subnet-05810d3f933925f6d"), // subnet1 + DefaultForAz: aws.Bool(false), + }, + }, + wantDefaultSubnet: []*ec2.Subnet{ + { + SubnetId: aws.String("subnet-44fe0c65"), // us-east-1a + DefaultForAz: aws.Bool(true), + }, + { + SubnetId: aws.String("subnet-65e16628"), // us-east-1b + DefaultForAz: aws.Bool(true), + }, + { + SubnetId: aws.String("subnet-afa656f0"), // us-east-1c + DefaultForAz: aws.Bool(true), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(2) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + gotSubnet, gotDefaultSubnet, err := r.ListAllSubnets() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, cachedDefaultData, err := r.ListAllSubnets() + assert.NoError(t, err) + assert.Equal(t, gotSubnet, cachedData) + assert.Equal(t, gotDefaultSubnet, cachedDefaultData) + assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllSubnets")) + assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllDefaultSubnets")) + } + + changelog, err := diff.Diff(gotSubnet, tt.wantSubnet) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + changelog, err = diff.Diff(gotDefaultSubnet, tt.wantDefaultSubnet) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllNatGateways(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.NatGateway + wantErr error + }{ + { + name: "List only gateways with multiple pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeNatGatewaysPages", + &ec2.DescribeNatGatewaysInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeNatGatewaysOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeNatGatewaysOutput{ + NatGateways: []*ec2.NatGateway{ + { + NatGatewayId: aws.String("nat-0"), + }, + { + NatGatewayId: aws.String("nat-1"), + }, + }, + }, false) + callback(&ec2.DescribeNatGatewaysOutput{ + NatGateways: []*ec2.NatGateway{ + { + NatGatewayId: aws.String("nat-2"), + }, + { + NatGatewayId: aws.String("nat-3"), + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.NatGateway{ + { + NatGatewayId: aws.String("nat-0"), + }, + { + NatGatewayId: aws.String("nat-1"), + }, + { + NatGatewayId: aws.String("nat-2"), + }, + { + NatGatewayId: aws.String("nat-3"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllNatGateways() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllNatGateways() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.NatGateway{}, store.Get("ec2ListAllNatGateways")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllRouteTables(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.RouteTable + wantErr error + }{ + { + name: "List only route with multiple pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeRouteTablesPages", + &ec2.DescribeRouteTablesInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeRouteTablesOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeRouteTablesOutput{ + RouteTables: []*ec2.RouteTable{ + { + RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: aws.String("1.1.1.1/32"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + { + DestinationIpv6CidrBlock: aws.String("::/0"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + }, + }, + { + RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + { + DestinationIpv6CidrBlock: aws.String("::/0"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + }, + }, + }, + }, false) + callback(&ec2.DescribeRouteTablesOutput{ + RouteTables: []*ec2.RouteTable{ + { + RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table + VpcId: aws.String("vpc-09fe5abc2309ba49d"), + Associations: []*ec2.RouteTableAssociation{ + { + Main: aws.Bool(true), + }, + }, + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: aws.String("10.1.1.0/24"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + { + DestinationCidrBlock: aws.String("10.1.2.0/24"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + }, + }, + { + RouteTableId: aws.String(""), // table3 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + }, + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.RouteTable{ + { + RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: aws.String("1.1.1.1/32"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + { + DestinationIpv6CidrBlock: aws.String("::/0"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + }, + }, + { + RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: aws.String("0.0.0.0/0"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + { + DestinationIpv6CidrBlock: aws.String("::/0"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + }, + }, + { + RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table + VpcId: aws.String("vpc-09fe5abc2309ba49d"), + Associations: []*ec2.RouteTableAssociation{ + { + Main: aws.Bool(true), + }, + }, + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: aws.String("10.1.1.0/24"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + { + DestinationCidrBlock: aws.String("10.1.2.0/24"), + GatewayId: aws.String("igw-030e74f73bd67f21b"), + }, + }, + }, + { + RouteTableId: aws.String(""), // table3 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: aws.String("10.0.0.0/16"), + Origin: aws.String("CreateRouteTable"), // default route + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllRouteTables() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllRouteTables() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.RouteTable{}, store.Get("ec2ListAllRouteTables")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllVPCs(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + wantVPC []*ec2.Vpc + wantDefaultVPC []*ec2.Vpc + wantErr error + }{ + { + name: "mixed default VPC and VPC", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeVpcsPages", + &ec2.DescribeVpcsInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeVpcsOutput{ + Vpcs: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-a8c5d4c1"), + IsDefault: aws.Bool(true), + }, + { + VpcId: aws.String("vpc-0768e1fd0029e3fc3"), + }, + { + VpcId: aws.String("vpc-020b072316a95b97f"), + IsDefault: aws.Bool(false), + }, + }, + }, false) + callback(&ec2.DescribeVpcsOutput{ + Vpcs: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-02c50896b59598761"), + IsDefault: aws.Bool(false), + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + wantVPC: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-0768e1fd0029e3fc3"), + }, + { + VpcId: aws.String("vpc-020b072316a95b97f"), + IsDefault: aws.Bool(false), + }, + { + VpcId: aws.String("vpc-02c50896b59598761"), + IsDefault: aws.Bool(false), + }, + }, + wantDefaultVPC: []*ec2.Vpc{ + { + VpcId: aws.String("vpc-a8c5d4c1"), + IsDefault: aws.Bool(true), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(2) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + gotVPCs, gotDefaultVPCs, err := r.ListAllVPCs() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, cachedDefaultData, err := r.ListAllVPCs() + assert.NoError(t, err) + assert.Equal(t, gotVPCs, cachedData) + assert.Equal(t, gotDefaultVPCs, cachedDefaultData) + assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllVPCs")) + assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllDefaultVPCs")) + } + + changelog, err := diff.Diff(gotVPCs, tt.wantVPC) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + changelog, err = diff.Diff(gotDefaultVPCs, tt.wantDefaultVPC) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllSecurityGroups(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + wantSecurityGroup []*ec2.SecurityGroup + wantDefaultSecurityGroup []*ec2.SecurityGroup + wantErr error + }{ + { + name: "List with 1 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeSecurityGroupsPages", + &ec2.DescribeSecurityGroupsInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeSecurityGroupsOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeSecurityGroupsOutput{ + SecurityGroups: []*ec2.SecurityGroup{ + { + GroupId: aws.String("sg-0254c038e32f25530"), + GroupName: aws.String("foo"), + }, + { + GroupId: aws.String("sg-9e0204ff"), + GroupName: aws.String("default"), + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + wantSecurityGroup: []*ec2.SecurityGroup{ + { + GroupId: aws.String("sg-0254c038e32f25530"), + GroupName: aws.String("foo"), + }, + }, + wantDefaultSecurityGroup: []*ec2.SecurityGroup{ + { + GroupId: aws.String("sg-9e0204ff"), + GroupName: aws.String("default"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(2) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + gotSecurityGroups, gotDefaultSecurityGroups, err := r.ListAllSecurityGroups() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, cachedDefaultData, err := r.ListAllSecurityGroups() + assert.NoError(t, err) + assert.Equal(t, gotSecurityGroups, cachedData) + assert.Equal(t, gotDefaultSecurityGroups, cachedDefaultData) + assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllSecurityGroups")) + assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllDefaultSecurityGroups")) + } + + changelog, err := diff.Diff(gotSecurityGroups, tt.wantSecurityGroup) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + changelog, err = diff.Diff(gotDefaultSecurityGroups, tt.wantDefaultSecurityGroup) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ec2Repository_ListAllNetworkACLs(t *testing.T) { + + testErr := errors.New("test") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.NetworkAcl + wantErr error + }{ + { + name: "List with 1 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeNetworkAclsPages", + &ec2.DescribeNetworkAclsInput{}, + mock.MatchedBy(func(callback func(res *ec2.DescribeNetworkAclsOutput, lastPage bool) bool) bool { + callback(&ec2.DescribeNetworkAclsOutput{ + NetworkAcls: []*ec2.NetworkAcl{ + { + NetworkAclId: aws.String("id1"), + }, + { + NetworkAclId: aws.String("id2"), + }, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ec2.NetworkAcl{ + { + NetworkAclId: aws.String("id1"), + }, + { + NetworkAclId: aws.String("id2"), + }, + }, + }, + { + name: "List return error", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeNetworkAclsPages", + &ec2.DescribeNetworkAclsInput{}, + mock.Anything, + ).Return(testErr) + }, + wantErr: testErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(2) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllNetworkACLs() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllNetworkACLs() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.NetworkAcl{}, store.Get("ec2ListAllNetworkACLs")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + client.AssertExpectations(t) + }) + } +} + +func Test_ec2Repository_DescribeLaunchTemplates(t *testing.T) { + + testErr := errors.New("test") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2) + want []*ec2.LaunchTemplate + wantErr error + }{ + { + name: "List with 1 pages", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeLaunchTemplates", + &ec2.DescribeLaunchTemplatesInput{}, + ).Return(&ec2.DescribeLaunchTemplatesOutput{ + LaunchTemplates: []*ec2.LaunchTemplate{ + { + LaunchTemplateId: aws.String("id1"), + }, + { + LaunchTemplateId: aws.String("id2"), + }, + }, + }, nil).Once() + }, + want: []*ec2.LaunchTemplate{ + { + LaunchTemplateId: aws.String("id1"), + }, + { + LaunchTemplateId: aws.String("id2"), + }, + }, + }, + { + name: "List return error", + mocks: func(client *awstest.MockFakeEC2) { + client.On("DescribeLaunchTemplates", + &ec2.DescribeLaunchTemplatesInput{}, + ).Return(nil, testErr) + }, + wantErr: testErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeEC2{} + tt.mocks(client) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.DescribeLaunchTemplates() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.DescribeLaunchTemplates() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ec2.LaunchTemplate{}, store.Get("DescribeLaunchTemplates")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + client.AssertExpectations(t) + }) + } +} + +func Test_ec2Repository_IsEbsEncryptionEnabledByDefault(t *testing.T) { + + testErr := errors.New("test") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeEC2, store *cache2.MockCache) + want bool + wantErr error + }{ + { + name: "test that encryption enabled by default", + mocks: func(client *awstest.MockFakeEC2, store *cache2.MockCache) { + store.On("Get", "ec2IsEbsEncryptionEnabledByDefault"). + Return(nil). + Once() + + client.On("GetEbsEncryptionByDefault", + &ec2.GetEbsEncryptionByDefaultInput{}, + ).Return(&ec2.GetEbsEncryptionByDefaultOutput{ + EbsEncryptionByDefault: aws.Bool(true), + }, nil).Once() + + store.On("Put", "ec2IsEbsEncryptionEnabledByDefault", true). + Return(false). + Once() + }, + want: true, + }, + { + name: "test that encryption enabled by default (cached)", + mocks: func(client *awstest.MockFakeEC2, store *cache2.MockCache) { + store.On("Get", "ec2IsEbsEncryptionEnabledByDefault"). + Return(false). + Once() + }, + want: false, + }, + { + name: "error while getting default encryption value", + mocks: func(client *awstest.MockFakeEC2, store *cache2.MockCache) { + store.On("Get", "ec2IsEbsEncryptionEnabledByDefault"). + Return(nil). + Once() + + client.On("GetEbsEncryptionByDefault", + &ec2.GetEbsEncryptionByDefaultInput{}, + ).Return(nil, testErr).Once() + }, + wantErr: testErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache2.MockCache{} + client := &awstest.MockFakeEC2{} + tt.mocks(client, store) + r := &ec2Repository{ + client: client, + cache: store, + } + got, err := r.IsEbsEncryptionEnabledByDefault() + + assert.Equal(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + + client.AssertExpectations(t) + store.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/ecr_repository.go b/enumeration/remote/aws/repository/ecr_repository.go new file mode 100644 index 00000000..d7b42563 --- /dev/null +++ b/enumeration/remote/aws/repository/ecr_repository.go @@ -0,0 +1,66 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ecr" + "github.com/aws/aws-sdk-go/service/ecr/ecriface" +) + +type ECRRepository interface { + ListAllRepositories() ([]*ecr.Repository, error) + GetRepositoryPolicy(*ecr.Repository) (*ecr.GetRepositoryPolicyOutput, error) +} + +type ecrRepository struct { + client ecriface.ECRAPI + cache cache.Cache +} + +func NewECRRepository(session *session.Session, c cache.Cache) *ecrRepository { + return &ecrRepository{ + ecr.New(session), + c, + } +} + +func (r *ecrRepository) ListAllRepositories() ([]*ecr.Repository, error) { + if v := r.cache.Get("ecrListAllRepositories"); v != nil { + return v.([]*ecr.Repository), nil + } + + var repositories []*ecr.Repository + input := &ecr.DescribeRepositoriesInput{} + err := r.client.DescribeRepositoriesPages(input, func(res *ecr.DescribeRepositoriesOutput, lastPage bool) bool { + repositories = append(repositories, res.Repositories...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("ecrListAllRepositories", repositories) + return repositories, nil +} + +func (r *ecrRepository) GetRepositoryPolicy(repo *ecr.Repository) (*ecr.GetRepositoryPolicyOutput, error) { + cacheKey := fmt.Sprintf("ecrListAllRepositoriesGetRepositoryPolicy_%s_%s", *repo.RegistryId, *repo.RepositoryName) + if v := r.cache.Get(cacheKey); v != nil { + return v.(*ecr.GetRepositoryPolicyOutput), nil + } + + var repositoryPolicyInput *ecr.GetRepositoryPolicyInput = &ecr.GetRepositoryPolicyInput{ + RegistryId: repo.RegistryId, + RepositoryName: repo.RepositoryName, + } + + repoOutput, err := r.client.GetRepositoryPolicy(repositoryPolicyInput) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, repoOutput) + return repoOutput, nil +} diff --git a/enumeration/remote/aws/repository/ecr_repository_test.go b/enumeration/remote/aws/repository/ecr_repository_test.go new file mode 100644 index 00000000..e36894e8 --- /dev/null +++ b/enumeration/remote/aws/repository/ecr_repository_test.go @@ -0,0 +1,171 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/service/ecr" + "github.com/pkg/errors" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/aws/aws-sdk-go/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_ecrRepository_ListAllRepositories(t *testing.T) { + + tests := []struct { + name string + mocks func(client *awstest.MockFakeECR) + want []*ecr.Repository + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeECR) { + client.On("DescribeRepositoriesPages", + &ecr.DescribeRepositoriesInput{}, + mock.MatchedBy(func(callback func(res *ecr.DescribeRepositoriesOutput, lastPage bool) bool) bool { + callback(&ecr.DescribeRepositoriesOutput{ + Repositories: []*ecr.Repository{ + {RepositoryName: aws.String("1")}, + {RepositoryName: aws.String("2")}, + {RepositoryName: aws.String("3")}, + }, + }, false) + callback(&ecr.DescribeRepositoriesOutput{ + Repositories: []*ecr.Repository{ + {RepositoryName: aws.String("4")}, + {RepositoryName: aws.String("5")}, + {RepositoryName: aws.String("6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*ecr.Repository{ + {RepositoryName: aws.String("1")}, + {RepositoryName: aws.String("2")}, + {RepositoryName: aws.String("3")}, + {RepositoryName: aws.String("4")}, + {RepositoryName: aws.String("5")}, + {RepositoryName: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeECR{} + tt.mocks(&client) + r := &ecrRepository{ + client: &client, + cache: store, + } + got, err := r.ListAllRepositories() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllRepositories() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*ecr.Repository{}, store.Get("ecrListAllRepositories")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ecrRepository_GetRepositoryPolicy(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeECR) + want *ecr.GetRepositoryPolicyOutput + wantErr error + }{ + { + name: "Get repository policy", + mocks: func(client *awstest.MockFakeECR) { + client.On("GetRepositoryPolicy", + &ecr.GetRepositoryPolicyInput{ + RegistryId: aws.String("1"), + RepositoryName: aws.String("2"), + }, + ).Return(&ecr.GetRepositoryPolicyOutput{ + RegistryId: aws.String("1"), + RepositoryName: aws.String("2"), + }, nil).Once() + }, + want: &ecr.GetRepositoryPolicyOutput{ + RegistryId: aws.String("1"), + RepositoryName: aws.String("2"), + }, + }, + { + name: "Get repository policy error", + mocks: func(client *awstest.MockFakeECR) { + client.On("GetRepositoryPolicy", + &ecr.GetRepositoryPolicyInput{ + RegistryId: aws.String("1"), + RepositoryName: aws.String("2"), + }, + ).Return(nil, dummyError).Once() + }, + wantErr: dummyError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeECR{} + tt.mocks(&client) + r := &ecrRepository{ + client: &client, + cache: store, + } + + repo := &ecr.Repository{ + RegistryId: aws.String("1"), + RepositoryName: aws.String("2"), + } + + got, err := r.GetRepositoryPolicy(repo) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.GetRepositoryPolicy(repo) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + cacheKey := fmt.Sprintf("ecrListAllRepositoriesGetRepositoryPolicy_%s_%s", *repo.RegistryId, *repo.RepositoryName) + assert.IsType(t, &ecr.GetRepositoryPolicyOutput{}, store.Get(cacheKey)) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/elasticache_repository.go b/enumeration/remote/aws/repository/elasticache_repository.go new file mode 100644 index 00000000..9fc991ae --- /dev/null +++ b/enumeration/remote/aws/repository/elasticache_repository.go @@ -0,0 +1,45 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/elasticache" + "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type ElastiCacheRepository interface { + ListAllCacheClusters() ([]*elasticache.CacheCluster, error) +} + +type elasticacheRepository struct { + client elasticacheiface.ElastiCacheAPI + cache cache.Cache +} + +func NewElastiCacheRepository(session *session.Session, c cache.Cache) *elasticacheRepository { + return &elasticacheRepository{ + elasticache.New(session), + c, + } +} + +func (r *elasticacheRepository) ListAllCacheClusters() ([]*elasticache.CacheCluster, error) { + if v := r.cache.Get("elasticacheListAllCacheClusters"); v != nil { + return v.([]*elasticache.CacheCluster), nil + } + + var clusters []*elasticache.CacheCluster + input := elasticache.DescribeCacheClustersInput{} + err := r.client.DescribeCacheClustersPages(&input, + func(resp *elasticache.DescribeCacheClustersOutput, lastPage bool) bool { + clusters = append(clusters, resp.CacheClusters...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put("elasticacheListAllCacheClusters", clusters) + return clusters, nil +} diff --git a/enumeration/remote/aws/repository/elasticache_repository_test.go b/enumeration/remote/aws/repository/elasticache_repository_test.go new file mode 100644 index 00000000..110bff12 --- /dev/null +++ b/enumeration/remote/aws/repository/elasticache_repository_test.go @@ -0,0 +1,96 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/elasticache" + "github.com/pkg/errors" + "github.com/r3labs/diff/v2" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_elasticacheRepository_ListAllCacheClusters(t *testing.T) { + clusters := []*elasticache.CacheCluster{ + {CacheClusterId: aws.String("cluster1")}, + {CacheClusterId: aws.String("cluster2")}, + {CacheClusterId: aws.String("cluster3")}, + {CacheClusterId: aws.String("cluster4")}, + {CacheClusterId: aws.String("cluster5")}, + {CacheClusterId: aws.String("cluster6")}, + } + + remoteError := errors.New("remote error") + + tests := []struct { + name string + mocks func(client *awstest.MockFakeElastiCache, store *cache.MockCache) + want []*elasticache.CacheCluster + wantErr error + }{ + { + name: "List cache clusters", + mocks: func(client *awstest.MockFakeElastiCache, store *cache.MockCache) { + client.On("DescribeCacheClustersPages", + &elasticache.DescribeCacheClustersInput{}, + mock.MatchedBy(func(callback func(res *elasticache.DescribeCacheClustersOutput, lastPage bool) bool) bool { + callback(&elasticache.DescribeCacheClustersOutput{ + CacheClusters: clusters[:3], + }, false) + callback(&elasticache.DescribeCacheClustersOutput{ + CacheClusters: clusters[3:], + }, true) + return true + })).Return(nil).Once() + store.On("Get", "elasticacheListAllCacheClusters").Return(nil).Times(1) + store.On("Put", "elasticacheListAllCacheClusters", clusters).Return(false).Times(1) + }, + want: clusters, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeElastiCache, store *cache.MockCache) { + store.On("Get", "elasticacheListAllCacheClusters").Return(clusters).Times(1) + }, + want: clusters, + }, + { + name: "should return remote error", + mocks: func(client *awstest.MockFakeElastiCache, store *cache.MockCache) { + client.On("DescribeCacheClustersPages", + &elasticache.DescribeCacheClustersInput{}, + mock.AnythingOfType("func(*elasticache.DescribeCacheClustersOutput, bool) bool")).Return(remoteError).Once() + store.On("Get", "elasticacheListAllCacheClusters").Return(nil).Times(1) + }, + wantErr: remoteError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeElastiCache{} + tt.mocks(client, store) + r := &elasticacheRepository{ + client: client, + cache: store, + } + got, err := r.ListAllCacheClusters() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + store.AssertExpectations(t) + client.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/elb_repository.go b/enumeration/remote/aws/repository/elb_repository.go new file mode 100644 index 00000000..46c05cba --- /dev/null +++ b/enumeration/remote/aws/repository/elb_repository.go @@ -0,0 +1,43 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/elb" + "github.com/aws/aws-sdk-go/service/elb/elbiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type ELBRepository interface { + ListAllLoadBalancers() ([]*elb.LoadBalancerDescription, error) +} + +type elbRepository struct { + client elbiface.ELBAPI + cache cache.Cache +} + +func NewELBRepository(session *session.Session, c cache.Cache) *elbRepository { + return &elbRepository{ + elb.New(session), + c, + } +} + +func (r *elbRepository) ListAllLoadBalancers() ([]*elb.LoadBalancerDescription, error) { + if v := r.cache.Get("elbListAllLoadBalancers"); v != nil { + return v.([]*elb.LoadBalancerDescription), nil + } + + results := make([]*elb.LoadBalancerDescription, 0) + input := elb.DescribeLoadBalancersInput{} + err := r.client.DescribeLoadBalancersPages(&input, func(res *elb.DescribeLoadBalancersOutput, lastPage bool) bool { + results = append(results, res.LoadBalancerDescriptions...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("elbListAllLoadBalancers", results) + return results, nil +} diff --git a/enumeration/remote/aws/repository/elb_repository_test.go b/enumeration/remote/aws/repository/elb_repository_test.go new file mode 100644 index 00000000..b21b06e0 --- /dev/null +++ b/enumeration/remote/aws/repository/elb_repository_test.go @@ -0,0 +1,119 @@ +package repository + +import ( + "errors" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/elb" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_ELBRepository_ListAllLoadBalancers(t *testing.T) { + dummyErr := errors.New("dummy error") + + results := []*elb.LoadBalancerDescription{ + { + LoadBalancerName: aws.String("test-lb-1"), + }, + { + LoadBalancerName: aws.String("test-lb-2"), + }, + } + + tests := []struct { + name string + mocks func(*awstest.MockFakeELB, *cache.MockCache) + want []*elb.LoadBalancerDescription + wantErr error + }{ + { + name: "List load balancers with multiple pages", + mocks: func(client *awstest.MockFakeELB, store *cache.MockCache) { + store.On("Get", "elbListAllLoadBalancers").Return(nil).Once() + + client.On("DescribeLoadBalancersPages", + &elb.DescribeLoadBalancersInput{}, + mock.MatchedBy(func(callback func(res *elb.DescribeLoadBalancersOutput, lastPage bool) bool) bool { + callback(&elb.DescribeLoadBalancersOutput{LoadBalancerDescriptions: []*elb.LoadBalancerDescription{ + results[0], + }}, false) + callback(&elb.DescribeLoadBalancersOutput{LoadBalancerDescriptions: []*elb.LoadBalancerDescription{ + results[1], + }}, true) + return true + })).Return(nil).Once() + + store.On("Put", "elbListAllLoadBalancers", results).Return(false).Once() + }, + want: []*elb.LoadBalancerDescription{ + { + LoadBalancerName: aws.String("test-lb-1"), + }, + { + LoadBalancerName: aws.String("test-lb-2"), + }, + }, + }, + { + name: "List load balancers with multiple pages (cache hit)", + mocks: func(client *awstest.MockFakeELB, store *cache.MockCache) { + store.On("Get", "elbListAllLoadBalancers").Return(results).Once() + }, + want: []*elb.LoadBalancerDescription{ + { + LoadBalancerName: aws.String("test-lb-1"), + }, + { + LoadBalancerName: aws.String("test-lb-2"), + }, + }, + }, + { + name: "Error listing load balancers", + mocks: func(client *awstest.MockFakeELB, store *cache.MockCache) { + store.On("Get", "elbListAllLoadBalancers").Return(nil).Once() + + client.On("DescribeLoadBalancersPages", + &elb.DescribeLoadBalancersInput{}, + mock.MatchedBy(func(callback func(res *elb.DescribeLoadBalancersOutput, lastPage bool) bool) bool { + callback(&elb.DescribeLoadBalancersOutput{LoadBalancerDescriptions: []*elb.LoadBalancerDescription{}}, true) + return true + })).Return(dummyErr).Once() + }, + wantErr: dummyErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeELB{} + tt.mocks(client, store) + r := &elbRepository{ + client: client, + cache: store, + } + got, err := r.ListAllLoadBalancers() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + + client.AssertExpectations(t) + store.AssertExpectations(t) + }) + } +} diff --git a/enumeration/remote/aws/repository/elbv2_repository.go b/enumeration/remote/aws/repository/elbv2_repository.go new file mode 100644 index 00000000..28639b15 --- /dev/null +++ b/enumeration/remote/aws/repository/elbv2_repository.go @@ -0,0 +1,65 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/aws/aws-sdk-go/service/elbv2/elbv2iface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type ELBV2Repository interface { + ListAllLoadBalancers() ([]*elbv2.LoadBalancer, error) + ListAllLoadBalancerListeners(string) ([]*elbv2.Listener, error) +} + +type elbv2Repository struct { + client elbv2iface.ELBV2API + cache cache.Cache +} + +func NewELBV2Repository(session *session.Session, c cache.Cache) *elbv2Repository { + return &elbv2Repository{ + elbv2.New(session), + c, + } +} + +func (r *elbv2Repository) ListAllLoadBalancers() ([]*elbv2.LoadBalancer, error) { + cacheKey := "elbv2ListAllLoadBalancers" + defer r.cache.Unlock(cacheKey) + if v := r.cache.GetAndLock(cacheKey); v != nil { + return v.([]*elbv2.LoadBalancer), nil + } + + results := make([]*elbv2.LoadBalancer, 0) + input := &elbv2.DescribeLoadBalancersInput{} + err := r.client.DescribeLoadBalancersPages(input, func(res *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool { + results = append(results, res.LoadBalancers...) + return !lastPage + }) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, results) + return results, err +} + +func (r *elbv2Repository) ListAllLoadBalancerListeners(loadBalancerArn string) ([]*elbv2.Listener, error) { + if v := r.cache.Get("elbv2ListAllLoadBalancerListeners"); v != nil { + return v.([]*elbv2.Listener), nil + } + + results := make([]*elbv2.Listener, 0) + input := &elbv2.DescribeListenersInput{ + LoadBalancerArn: &loadBalancerArn, + } + err := r.client.DescribeListenersPages(input, func(res *elbv2.DescribeListenersOutput, lastPage bool) bool { + results = append(results, res.Listeners...) + return !lastPage + }) + if err != nil { + return nil, err + } + r.cache.Put("elbv2ListAllLoadBalancerListeners", results) + return results, err +} diff --git a/enumeration/remote/aws/repository/elbv2_repository_test.go b/enumeration/remote/aws/repository/elbv2_repository_test.go new file mode 100644 index 00000000..2ca320d6 --- /dev/null +++ b/enumeration/remote/aws/repository/elbv2_repository_test.go @@ -0,0 +1,243 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/pkg/errors" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_ELBV2Repository_ListAllLoadBalancers(t *testing.T) { + dummyError := errors.New("dummy error") + + tests := []struct { + name string + mocks func(*awstest.MockFakeELBV2, *cache.MockCache) + want []*elbv2.LoadBalancer + wantErr error + }{ + { + name: "list load balancers", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + results := &elbv2.DescribeLoadBalancersOutput{ + LoadBalancers: []*elbv2.LoadBalancer{ + { + LoadBalancerArn: aws.String("test-1"), + LoadBalancerName: aws.String("test-1"), + }, + { + LoadBalancerArn: aws.String("test-2"), + LoadBalancerName: aws.String("test-2"), + }, + }, + } + + store.On("GetAndLock", "elbv2ListAllLoadBalancers").Return(nil).Once() + store.On("Unlock", "elbv2ListAllLoadBalancers").Return().Once() + + client.On("DescribeLoadBalancersPages", + &elbv2.DescribeLoadBalancersInput{}, + mock.MatchedBy(func(callback func(res *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool) bool { + callback(&elbv2.DescribeLoadBalancersOutput{LoadBalancers: []*elbv2.LoadBalancer{ + results.LoadBalancers[0], + }}, false) + callback(&elbv2.DescribeLoadBalancersOutput{LoadBalancers: []*elbv2.LoadBalancer{ + results.LoadBalancers[1], + }}, true) + return true + })).Return(nil).Once() + + store.On("Put", "elbv2ListAllLoadBalancers", results.LoadBalancers).Return(false).Once() + }, + want: []*elbv2.LoadBalancer{ + { + LoadBalancerArn: aws.String("test-1"), + LoadBalancerName: aws.String("test-1"), + }, + { + LoadBalancerArn: aws.String("test-2"), + LoadBalancerName: aws.String("test-2"), + }, + }, + }, + { + name: "list load balancers from cache", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + output := &elbv2.DescribeLoadBalancersOutput{ + LoadBalancers: []*elbv2.LoadBalancer{ + { + LoadBalancerArn: aws.String("test-1"), + LoadBalancerName: aws.String("test-1"), + }, + }, + } + + store.On("GetAndLock", "elbv2ListAllLoadBalancers").Return(output.LoadBalancers).Once() + store.On("Unlock", "elbv2ListAllLoadBalancers").Return().Once() + }, + want: []*elbv2.LoadBalancer{ + { + LoadBalancerArn: aws.String("test-1"), + LoadBalancerName: aws.String("test-1"), + }, + }, + }, + { + name: "error listing load balancers", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + store.On("GetAndLock", "elbv2ListAllLoadBalancers").Return(nil).Once() + store.On("Unlock", "elbv2ListAllLoadBalancers").Return().Once() + + client.On("DescribeLoadBalancersPages", + &elbv2.DescribeLoadBalancersInput{}, + mock.MatchedBy(func(callback func(res *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool) bool { + callback(&elbv2.DescribeLoadBalancersOutput{LoadBalancers: []*elbv2.LoadBalancer{}}, true) + return true + })).Return(dummyError).Once() + }, + wantErr: dummyError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeELBV2{} + tt.mocks(client, store) + r := &elbv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllLoadBalancers() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_ELBV2Repository_ListAllLoadBalancerListeners(t *testing.T) { + dummyError := errors.New("dummy error") + + tests := []struct { + name string + mocks func(*awstest.MockFakeELBV2, *cache.MockCache) + want []*elbv2.Listener + wantErr error + }{ + { + name: "list load balancer listeners", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + results := &elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener-1"), + }, + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener-2"), + }, + }, + } + + store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(nil).Once() + + client.On("DescribeListenersPages", + &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, + mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { + callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ + results.Listeners[0], + }}, false) + callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ + results.Listeners[1], + }}, true) + return true + })).Return(nil).Once() + + store.On("Put", "elbv2ListAllLoadBalancerListeners", results.Listeners).Return(false).Once() + }, + want: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener-1"), + }, + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener-2"), + }, + }, + }, + { + name: "list load balancer listeners from cache", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + output := &elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener"), + }, + }, + } + + store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(output.Listeners).Once() + }, + want: []*elbv2.Listener{ + { + LoadBalancerArn: aws.String("test-lb"), + ListenerArn: aws.String("test-lb-listener"), + }, + }, + }, + { + name: "error listing load balancer listeners", + mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { + store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(nil).Once() + + client.On("DescribeListenersPages", + &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, + mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { + callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{}}, true) + return true + })).Return(dummyError).Once() + }, + wantErr: dummyError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache.MockCache{} + client := &awstest.MockFakeELBV2{} + tt.mocks(client, store) + r := &elbv2Repository{ + client: client, + cache: store, + } + got, err := r.ListAllLoadBalancerListeners("test-lb") + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/iam_repository.go b/enumeration/remote/aws/repository/iam_repository.go new file mode 100644 index 00000000..e7882fe9 --- /dev/null +++ b/enumeration/remote/aws/repository/iam_repository.go @@ -0,0 +1,367 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/iam/iamiface" +) + +type IAMRepository interface { + ListAllAccessKeys([]*iam.User) ([]*iam.AccessKeyMetadata, error) + ListAllUsers() ([]*iam.User, error) + ListAllPolicies() ([]*iam.Policy, error) + ListAllRoles() ([]*iam.Role, error) + ListAllRolePolicyAttachments([]*iam.Role) ([]*AttachedRolePolicy, error) + ListAllRolePolicies([]*iam.Role) ([]RolePolicy, error) + ListAllUserPolicyAttachments([]*iam.User) ([]*AttachedUserPolicy, error) + ListAllUserPolicies([]*iam.User) ([]string, error) + ListAllGroups() ([]*iam.Group, error) + ListAllGroupPolicies([]*iam.Group) ([]string, error) + ListAllGroupPolicyAttachments([]*iam.Group) ([]*AttachedGroupPolicy, error) +} + +type iamRepository struct { + client iamiface.IAMAPI + cache cache.Cache +} + +func NewIAMRepository(session *session.Session, c cache.Cache) *iamRepository { + return &iamRepository{ + iam.New(session), + c, + } +} + +func (r *iamRepository) ListAllAccessKeys(users []*iam.User) ([]*iam.AccessKeyMetadata, error) { + var resources []*iam.AccessKeyMetadata + for _, user := range users { + cacheKey := fmt.Sprintf("iamListAllAccessKeys_user_%s", *user.UserName) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]*iam.AccessKeyMetadata)...) + continue + } + + userResources := make([]*iam.AccessKeyMetadata, 0) + input := &iam.ListAccessKeysInput{ + UserName: user.UserName, + } + err := r.client.ListAccessKeysPages(input, func(res *iam.ListAccessKeysOutput, lastPage bool) bool { + userResources = append(userResources, res.AccessKeyMetadata...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, userResources) + resources = append(resources, userResources...) + } + + return resources, nil +} + +func (r *iamRepository) ListAllUsers() ([]*iam.User, error) { + + cacheKey := "iamListAllUsers" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*iam.User), nil + } + + var resources []*iam.User + input := &iam.ListUsersInput{} + err := r.client.ListUsersPages(input, func(res *iam.ListUsersOutput, lastPage bool) bool { + resources = append(resources, res.Users...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources) + return resources, nil +} + +func (r *iamRepository) ListAllPolicies() ([]*iam.Policy, error) { + if v := r.cache.Get("iamListAllPolicies"); v != nil { + return v.([]*iam.Policy), nil + } + + var resources []*iam.Policy + input := &iam.ListPoliciesInput{ + Scope: aws.String(iam.PolicyScopeTypeLocal), + } + err := r.client.ListPoliciesPages(input, func(res *iam.ListPoliciesOutput, lastPage bool) bool { + resources = append(resources, res.Policies...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("iamListAllPolicies", resources) + return resources, nil +} + +func (r *iamRepository) ListAllRoles() ([]*iam.Role, error) { + cacheKey := "iamListAllRoles" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*iam.Role), nil + } + + var resources []*iam.Role + input := &iam.ListRolesInput{} + err := r.client.ListRolesPages(input, func(res *iam.ListRolesOutput, lastPage bool) bool { + resources = append(resources, res.Roles...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources) + return resources, nil +} + +func (r *iamRepository) ListAllRolePolicyAttachments(roles []*iam.Role) ([]*AttachedRolePolicy, error) { + var resources []*AttachedRolePolicy + for _, role := range roles { + cacheKey := fmt.Sprintf("iamListAllRolePolicyAttachments_role_%s", *role.RoleName) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]*AttachedRolePolicy)...) + continue + } + + roleResources := make([]*AttachedRolePolicy, 0) + input := &iam.ListAttachedRolePoliciesInput{ + RoleName: role.RoleName, + } + err := r.client.ListAttachedRolePoliciesPages(input, func(res *iam.ListAttachedRolePoliciesOutput, lastPage bool) bool { + for _, policy := range res.AttachedPolicies { + p := *policy + roleResources = append(roleResources, &AttachedRolePolicy{ + AttachedPolicy: p, + RoleName: *input.RoleName, + }) + } + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, roleResources) + resources = append(resources, roleResources...) + } + + return resources, nil +} + +func (r *iamRepository) ListAllRolePolicies(roles []*iam.Role) ([]RolePolicy, error) { + var resources []RolePolicy + for _, role := range roles { + cacheKey := fmt.Sprintf("iamListAllRolePolicies_role_%s", *role.RoleName) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]RolePolicy)...) + continue + } + + roleResources := make([]RolePolicy, 0) + input := &iam.ListRolePoliciesInput{ + RoleName: role.RoleName, + } + err := r.client.ListRolePoliciesPages(input, func(res *iam.ListRolePoliciesOutput, lastPage bool) bool { + for _, policy := range res.PolicyNames { + roleResources = append(roleResources, RolePolicy{*policy, *input.RoleName}) + } + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, roleResources) + resources = append(resources, roleResources...) + } + + return resources, nil +} + +func (r *iamRepository) ListAllUserPolicyAttachments(users []*iam.User) ([]*AttachedUserPolicy, error) { + var resources []*AttachedUserPolicy + for _, user := range users { + cacheKey := fmt.Sprintf("iamListAllUserPolicyAttachments_user_%s", *user.UserName) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]*AttachedUserPolicy)...) + continue + } + + userResources := make([]*AttachedUserPolicy, 0) + input := &iam.ListAttachedUserPoliciesInput{ + UserName: user.UserName, + } + err := r.client.ListAttachedUserPoliciesPages(input, func(res *iam.ListAttachedUserPoliciesOutput, lastPage bool) bool { + for _, policy := range res.AttachedPolicies { + p := *policy + userResources = append(userResources, &AttachedUserPolicy{ + AttachedPolicy: p, + UserName: *input.UserName, + }) + } + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, userResources) + resources = append(resources, userResources...) + } + + return resources, nil +} + +func (r *iamRepository) ListAllUserPolicies(users []*iam.User) ([]string, error) { + var resources []string + for _, user := range users { + cacheKey := fmt.Sprintf("iamListAllUserPolicies_user_%s", *user.UserName) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]string)...) + continue + } + + userResources := make([]string, 0) + input := &iam.ListUserPoliciesInput{ + UserName: user.UserName, + } + err := r.client.ListUserPoliciesPages(input, func(res *iam.ListUserPoliciesOutput, lastPage bool) bool { + for _, polName := range res.PolicyNames { + userResources = append(userResources, fmt.Sprintf("%s:%s", *input.UserName, *polName)) + } + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, userResources) + resources = append(resources, userResources...) + } + + return resources, nil +} + +func (r *iamRepository) ListAllGroups() ([]*iam.Group, error) { + + cacheKey := "iamListAllGroups" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + + if v != nil { + return v.([]*iam.Group), nil + } + + var resources []*iam.Group + input := &iam.ListGroupsInput{} + err := r.client.ListGroupsPages(input, func(res *iam.ListGroupsOutput, lastPage bool) bool { + resources = append(resources, res.Groups...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, resources) + return resources, nil +} + +func (r *iamRepository) ListAllGroupPolicies(groups []*iam.Group) ([]string, error) { + var resources []string + for _, group := range groups { + cacheKey := fmt.Sprintf("iamListAllGroupPolicies_group_%s", *group.GroupName) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]string)...) + continue + } + + groupResources := make([]string, 0) + input := &iam.ListGroupPoliciesInput{ + GroupName: group.GroupName, + } + err := r.client.ListGroupPoliciesPages(input, func(res *iam.ListGroupPoliciesOutput, lastPage bool) bool { + for _, polName := range res.PolicyNames { + groupResources = append(groupResources, fmt.Sprintf("%s:%s", *input.GroupName, *polName)) + } + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, groupResources) + resources = append(resources, groupResources...) + } + + return resources, nil +} + +func (r *iamRepository) ListAllGroupPolicyAttachments(groups []*iam.Group) ([]*AttachedGroupPolicy, error) { + var resources []*AttachedGroupPolicy + for _, group := range groups { + cacheKey := fmt.Sprintf("iamListAllGroupPolicyAttachments_%s", *group.GroupId) + if v := r.cache.Get(cacheKey); v != nil { + resources = append(resources, v.([]*AttachedGroupPolicy)...) + continue + } + + attachedGroupPolicies := make([]*AttachedGroupPolicy, 0) + input := &iam.ListAttachedGroupPoliciesInput{ + GroupName: group.GroupName, + } + err := r.client.ListAttachedGroupPoliciesPages(input, func(res *iam.ListAttachedGroupPoliciesOutput, lastPage bool) bool { + for _, policy := range res.AttachedPolicies { + p := *policy + attachedGroupPolicies = append(attachedGroupPolicies, &AttachedGroupPolicy{ + AttachedPolicy: p, + GroupName: *input.GroupName, + }) + } + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, attachedGroupPolicies) + resources = append(resources, attachedGroupPolicies...) + } + + return resources, nil +} + +type AttachedUserPolicy struct { + iam.AttachedPolicy + UserName string +} + +type AttachedRolePolicy struct { + iam.AttachedPolicy + RoleName string +} + +type AttachedGroupPolicy struct { + iam.AttachedPolicy + GroupName string +} + +type RolePolicy struct { + Policy string + RoleName string +} diff --git a/enumeration/remote/aws/repository/iam_repository_test.go b/enumeration/remote/aws/repository/iam_repository_test.go new file mode 100644 index 00000000..8d2ad6e2 --- /dev/null +++ b/enumeration/remote/aws/repository/iam_repository_test.go @@ -0,0 +1,1100 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/iam" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_IAMRepository_ListAllAccessKeys(t *testing.T) { + tests := []struct { + name string + users []*iam.User + mocks func(client *awstest.MockFakeIAM) + want []*iam.AccessKeyMetadata + wantErr error + }{ + { + name: "List only access keys with multiple pages", + users: []*iam.User{ + { + UserName: aws.String("test-driftctl"), + }, + { + UserName: aws.String("test-driftctl2"), + }, + }, + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListAccessKeysPages", + &iam.ListAccessKeysInput{ + UserName: aws.String("test-driftctl"), + }, + mock.MatchedBy(func(callback func(res *iam.ListAccessKeysOutput, lastPage bool) bool) bool { + callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ + { + AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), + UserName: aws.String("test-driftctl"), + }, + }}, false) + callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ + { + AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), + UserName: aws.String("test-driftctl"), + }, + }}, true) + return true + })).Return(nil).Once() + client.On("ListAccessKeysPages", + &iam.ListAccessKeysInput{ + UserName: aws.String("test-driftctl2"), + }, + mock.MatchedBy(func(callback func(res *iam.ListAccessKeysOutput, lastPage bool) bool) bool { + callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ + { + AccessKeyId: aws.String("AKIA5QYBVVD26EJME25D"), + UserName: aws.String("test-driftctl2"), + }, + }}, false) + callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ + { + AccessKeyId: aws.String("AKIA5QYBVVD2SWDFVVMG"), + UserName: aws.String("test-driftctl2"), + }, + }}, true) + return true + })).Return(nil).Once() + }, + want: []*iam.AccessKeyMetadata{ + { + AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), + UserName: aws.String("test-driftctl"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), + UserName: aws.String("test-driftctl"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), + UserName: aws.String("test-driftctl"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), + UserName: aws.String("test-driftctl"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD26EJME25D"), + UserName: aws.String("test-driftctl2"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD2SWDFVVMG"), + UserName: aws.String("test-driftctl2"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(2) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllAccessKeys(tt.users) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAccessKeys(tt.users) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + for _, user := range tt.users { + assert.IsType(t, []*iam.AccessKeyMetadata{}, store.Get(fmt.Sprintf("iamListAllAccessKeys_user_%s", *user.UserName))) + } + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllUsers(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeIAM) + want []*iam.User + wantErr error + }{ + { + name: "List only users with multiple pages", + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListUsersPages", + &iam.ListUsersInput{}, + mock.MatchedBy(func(callback func(res *iam.ListUsersOutput, lastPage bool) bool) bool { + callback(&iam.ListUsersOutput{Users: []*iam.User{ + { + UserName: aws.String("test-driftctl"), + }, + { + UserName: aws.String("test-driftctl2"), + }, + }}, false) + callback(&iam.ListUsersOutput{Users: []*iam.User{ + { + UserName: aws.String("test-driftctl3"), + }, + { + UserName: aws.String("test-driftctl4"), + }, + }}, true) + return true + })).Return(nil).Once() + }, + want: []*iam.User{ + { + UserName: aws.String("test-driftctl"), + }, + { + UserName: aws.String("test-driftctl2"), + }, + { + UserName: aws.String("test-driftctl3"), + }, + { + UserName: aws.String("test-driftctl4"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllUsers() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllUsers() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*iam.User{}, store.Get("iamListAllUsers")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllPolicies(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeIAM) + want []*iam.Policy + wantErr error + }{ + { + name: "List only policies with multiple pages", + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListPoliciesPages", + &iam.ListPoliciesInput{Scope: aws.String(iam.PolicyScopeTypeLocal)}, + mock.MatchedBy(func(callback func(res *iam.ListPoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListPoliciesOutput{Policies: []*iam.Policy{ + { + PolicyName: aws.String("test-driftctl"), + }, + { + PolicyName: aws.String("test-driftctl2"), + }, + }}, false) + callback(&iam.ListPoliciesOutput{Policies: []*iam.Policy{ + { + PolicyName: aws.String("test-driftctl3"), + }, + { + PolicyName: aws.String("test-driftctl4"), + }, + }}, true) + return true + })).Return(nil).Once() + }, + want: []*iam.Policy{ + { + PolicyName: aws.String("test-driftctl"), + }, + { + PolicyName: aws.String("test-driftctl2"), + }, + { + PolicyName: aws.String("test-driftctl3"), + }, + { + PolicyName: aws.String("test-driftctl4"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllPolicies() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllPolicies() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*iam.Policy{}, store.Get("iamListAllPolicies")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllRoles(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeIAM) + want []*iam.Role + wantErr error + }{ + { + name: "List only roles with multiple pages", + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListRolesPages", + &iam.ListRolesInput{}, + mock.MatchedBy(func(callback func(res *iam.ListRolesOutput, lastPage bool) bool) bool { + callback(&iam.ListRolesOutput{Roles: []*iam.Role{ + { + RoleName: aws.String("test-driftctl"), + }, + { + RoleName: aws.String("test-driftctl2"), + }, + }}, false) + callback(&iam.ListRolesOutput{Roles: []*iam.Role{ + { + RoleName: aws.String("test-driftctl3"), + }, + { + RoleName: aws.String("test-driftctl4"), + }, + }}, true) + return true + })).Return(nil).Once() + }, + want: []*iam.Role{ + { + RoleName: aws.String("test-driftctl"), + }, + { + RoleName: aws.String("test-driftctl2"), + }, + { + RoleName: aws.String("test-driftctl3"), + }, + { + RoleName: aws.String("test-driftctl4"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRoles() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllRoles() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*iam.Role{}, store.Get("iamListAllRoles")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllRolePolicyAttachments(t *testing.T) { + tests := []struct { + name string + roles []*iam.Role + mocks func(client *awstest.MockFakeIAM) + want []*AttachedRolePolicy + wantErr error + }{ + { + name: "List only role policy attachments with multiple pages", + roles: []*iam.Role{ + { + RoleName: aws.String("test-role"), + }, + { + RoleName: aws.String("test-role2"), + }, + }, + mocks: func(client *awstest.MockFakeIAM) { + + shouldSkipfirst := false + shouldSkipSecond := false + + client.On("ListAttachedRolePoliciesPages", + &iam.ListAttachedRolePoliciesInput{ + RoleName: aws.String("test-role"), + }, + mock.MatchedBy(func(callback func(res *iam.ListAttachedRolePoliciesOutput, lastPage bool) bool) bool { + if shouldSkipfirst { + return false + } + callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), + PolicyName: aws.String("policy"), + }, + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), + PolicyName: aws.String("policy2"), + }, + }}, false) + callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), + PolicyName: aws.String("policy3"), + }, + }}, true) + shouldSkipfirst = true + return true + })).Return(nil).Once() + + client.On("ListAttachedRolePoliciesPages", + &iam.ListAttachedRolePoliciesInput{ + RoleName: aws.String("test-role2"), + }, + mock.MatchedBy(func(callback func(res *iam.ListAttachedRolePoliciesOutput, lastPage bool) bool) bool { + if shouldSkipSecond { + return false + } + callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), + PolicyName: aws.String("policy"), + }, + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), + PolicyName: aws.String("policy2"), + }, + }}, false) + callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), + PolicyName: aws.String("policy3"), + }, + }}, true) + shouldSkipSecond = true + return true + })).Return(nil).Once() + }, + want: []*AttachedRolePolicy{ + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), + PolicyName: aws.String("policy"), + }, + *aws.String("test-role"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), + PolicyName: aws.String("policy2"), + }, + *aws.String("test-role"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), + PolicyName: aws.String("policy3"), + }, + *aws.String("test-role"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), + PolicyName: aws.String("policy"), + }, + *aws.String("test-role2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), + PolicyName: aws.String("policy2"), + }, + *aws.String("test-role2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), + PolicyName: aws.String("policy3"), + }, + *aws.String("test-role2"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(2) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRolePolicyAttachments(tt.roles) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllRolePolicyAttachments(tt.roles) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + for _, role := range tt.roles { + assert.IsType(t, []*AttachedRolePolicy{}, store.Get(fmt.Sprintf("iamListAllRolePolicyAttachments_role_%s", *role.RoleName))) + } + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllRolePolicies(t *testing.T) { + tests := []struct { + name string + roles []*iam.Role + mocks func(client *awstest.MockFakeIAM) + want []RolePolicy + wantErr error + }{ + { + name: "List only role policies with multiple pages", + roles: []*iam.Role{ + { + RoleName: aws.String("test_role_0"), + }, + { + RoleName: aws.String("test_role_1"), + }, + }, + mocks: func(client *awstest.MockFakeIAM) { + firstMockCalled := false + client.On("ListRolePoliciesPages", + &iam.ListRolePoliciesInput{ + RoleName: aws.String("test_role_0"), + }, + mock.MatchedBy(func(callback func(res *iam.ListRolePoliciesOutput, lastPage bool) bool) bool { + if firstMockCalled { + return false + } + callback(&iam.ListRolePoliciesOutput{ + PolicyNames: []*string{ + aws.String("policy-role0-0"), + aws.String("policy-role0-1"), + }, + }, false) + callback(&iam.ListRolePoliciesOutput{ + PolicyNames: []*string{ + aws.String("policy-role0-2"), + }, + }, true) + firstMockCalled = true + return true + })).Once().Return(nil) + client.On("ListRolePoliciesPages", + &iam.ListRolePoliciesInput{ + RoleName: aws.String("test_role_1"), + }, + mock.MatchedBy(func(callback func(res *iam.ListRolePoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListRolePoliciesOutput{ + PolicyNames: []*string{ + aws.String("policy-role1-0"), + aws.String("policy-role1-1"), + }, + }, false) + callback(&iam.ListRolePoliciesOutput{ + PolicyNames: []*string{ + aws.String("policy-role1-2"), + }, + }, true) + return true + })).Once().Return(nil) + }, + want: []RolePolicy{ + {Policy: "policy-role0-0", RoleName: "test_role_0"}, + {Policy: "policy-role0-1", RoleName: "test_role_0"}, + {Policy: "policy-role0-2", RoleName: "test_role_0"}, + {Policy: "policy-role1-0", RoleName: "test_role_1"}, + {Policy: "policy-role1-1", RoleName: "test_role_1"}, + {Policy: "policy-role1-2", RoleName: "test_role_1"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(2) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllRolePolicies(tt.roles) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllRolePolicies(tt.roles) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + for _, role := range tt.roles { + assert.IsType(t, []RolePolicy{}, store.Get(fmt.Sprintf("iamListAllRolePolicies_role_%s", *role.RoleName))) + } + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllUserPolicyAttachments(t *testing.T) { + tests := []struct { + name string + users []*iam.User + mocks func(client *awstest.MockFakeIAM) + want []*AttachedUserPolicy + wantErr error + }{ + { + name: "List only user policy attachments with multiple pages", + users: []*iam.User{ + { + UserName: aws.String("loadbalancer"), + }, + { + UserName: aws.String("loadbalancer2"), + }, + }, + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListAttachedUserPoliciesPages", + &iam.ListAttachedUserPoliciesInput{ + UserName: aws.String("loadbalancer"), + }, + mock.MatchedBy(func(callback func(res *iam.ListAttachedUserPoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), + PolicyName: aws.String("test-attach"), + }, + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), + PolicyName: aws.String("test-attach2"), + }, + }}, false) + callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), + PolicyName: aws.String("test-attach3"), + }, + }}, true) + return true + })).Return(nil).Once() + + client.On("ListAttachedUserPoliciesPages", + &iam.ListAttachedUserPoliciesInput{ + UserName: aws.String("loadbalancer2"), + }, + mock.MatchedBy(func(callback func(res *iam.ListAttachedUserPoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), + PolicyName: aws.String("test-attach"), + }, + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), + PolicyName: aws.String("test-attach2"), + }, + }}, false) + callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ + { + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), + PolicyName: aws.String("test-attach3"), + }, + }}, true) + return true + })).Return(nil).Once() + }, + + want: []*AttachedUserPolicy{ + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), + PolicyName: aws.String("test-attach"), + }, + *aws.String("loadbalancer"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), + PolicyName: aws.String("test-attach2"), + }, + *aws.String("loadbalancer"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), + PolicyName: aws.String("test-attach3"), + }, + *aws.String("loadbalancer"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), + PolicyName: aws.String("test-attach"), + }, + *aws.String("loadbalancer2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), + PolicyName: aws.String("test-attach2"), + }, + *aws.String("loadbalancer2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), + PolicyName: aws.String("test-attach3"), + }, + *aws.String("loadbalancer2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), + PolicyName: aws.String("test-attach"), + }, + *aws.String("loadbalancer2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), + PolicyName: aws.String("test-attach2"), + }, + *aws.String("loadbalancer2"), + }, + { + iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), + PolicyName: aws.String("test-attach3"), + }, + *aws.String("loadbalancer2"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(2) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllUserPolicyAttachments(tt.users) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllUserPolicyAttachments(tt.users) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + for _, user := range tt.users { + assert.IsType(t, []*AttachedUserPolicy{}, store.Get(fmt.Sprintf("iamListAllUserPolicyAttachments_user_%s", *user.UserName))) + } + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllUserPolicies(t *testing.T) { + tests := []struct { + name string + users []*iam.User + mocks func(client *awstest.MockFakeIAM) + want []string + wantErr error + }{ + { + name: "List only user policies with multiple pages", + users: []*iam.User{ + { + UserName: aws.String("loadbalancer"), + }, + { + UserName: aws.String("loadbalancer2"), + }, + }, + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListUserPoliciesPages", + &iam.ListUserPoliciesInput{ + UserName: aws.String("loadbalancer"), + }, + mock.MatchedBy(func(callback func(res *iam.ListUserPoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ + aws.String("test"), + aws.String("test2"), + aws.String("test3"), + }}, false) + callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ + aws.String("test4"), + }}, true) + return true + })).Return(nil).Once() + + client.On("ListUserPoliciesPages", + &iam.ListUserPoliciesInput{ + UserName: aws.String("loadbalancer2"), + }, + mock.MatchedBy(func(callback func(res *iam.ListUserPoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ + aws.String("test2"), + aws.String("test22"), + aws.String("test23"), + }}, false) + callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ + aws.String("test24"), + }}, true) + return true + })).Return(nil).Once() + }, + want: []string{ + *aws.String("loadbalancer:test"), + *aws.String("loadbalancer:test2"), + *aws.String("loadbalancer:test3"), + *aws.String("loadbalancer:test4"), + *aws.String("loadbalancer2:test"), + *aws.String("loadbalancer2:test2"), + *aws.String("loadbalancer2:test3"), + *aws.String("loadbalancer2:test4"), + *aws.String("loadbalancer2:test2"), + *aws.String("loadbalancer2:test22"), + *aws.String("loadbalancer2:test23"), + *aws.String("loadbalancer2:test24"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(2) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllUserPolicies(tt.users) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllUserPolicies(tt.users) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + for _, user := range tt.users { + assert.IsType(t, []string{}, store.Get(fmt.Sprintf("iamListAllUserPolicies_user_%s", *user.UserName))) + } + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllGroups(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeIAM) + want []*iam.Group + wantErr error + }{ + { + name: "List groups with multiple pages", + mocks: func(client *awstest.MockFakeIAM) { + + client.On("ListGroupsPages", + &iam.ListGroupsInput{}, + mock.MatchedBy(func(callback func(res *iam.ListGroupsOutput, lastPage bool) bool) bool { + callback(&iam.ListGroupsOutput{Groups: []*iam.Group{ + { + GroupName: aws.String("group1"), + }, + { + GroupName: aws.String("group2"), + }, + }}, false) + callback(&iam.ListGroupsOutput{Groups: []*iam.Group{ + { + GroupName: aws.String("group3"), + }, + { + GroupName: aws.String("group4"), + }, + }}, true) + return true + })).Return(nil).Once() + }, + want: []*iam.Group{ + { + GroupName: aws.String("group1"), + }, + { + GroupName: aws.String("group2"), + }, + { + GroupName: aws.String("group3"), + }, + { + GroupName: aws.String("group4"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllGroups() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllGroups() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*iam.Group{}, store.Get("iamListAllGroups")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_IAMRepository_ListAllGroupPolicies(t *testing.T) { + tests := []struct { + name string + groups []*iam.Group + mocks func(client *awstest.MockFakeIAM) + want []string + wantErr error + }{ + { + name: "List only group policies with multiple pages", + groups: []*iam.Group{ + { + GroupName: aws.String("group1"), + }, + { + GroupName: aws.String("group2"), + }, + }, + mocks: func(client *awstest.MockFakeIAM) { + firstMockCalled := false + client.On("ListGroupPoliciesPages", + &iam.ListGroupPoliciesInput{ + GroupName: aws.String("group1"), + }, + mock.MatchedBy(func(callback func(res *iam.ListGroupPoliciesOutput, lastPage bool) bool) bool { + if firstMockCalled { + return false + } + callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ + aws.String("policy1"), + aws.String("policy2"), + aws.String("policy3"), + }}, false) + callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ + aws.String("policy4"), + }}, true) + firstMockCalled = true + return true + })).Return(nil).Once() + + client.On("ListGroupPoliciesPages", + &iam.ListGroupPoliciesInput{ + GroupName: aws.String("group2"), + }, + mock.MatchedBy(func(callback func(res *iam.ListGroupPoliciesOutput, lastPage bool) bool) bool { + callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ + aws.String("policy2"), + aws.String("policy22"), + aws.String("policy23"), + }}, false) + callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ + aws.String("policy24"), + }}, true) + return true + })).Return(nil).Once() + }, + want: []string{ + *aws.String("group1:policy1"), + *aws.String("group1:policy2"), + *aws.String("group1:policy3"), + *aws.String("group1:policy4"), + *aws.String("group2:policy2"), + *aws.String("group2:policy22"), + *aws.String("group2:policy23"), + *aws.String("group2:policy24"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(2) + client := &awstest.MockFakeIAM{} + tt.mocks(client) + r := &iamRepository{ + client: client, + cache: store, + } + got, err := r.ListAllGroupPolicies(tt.groups) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllGroupPolicies(tt.groups) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + for _, group := range tt.groups { + assert.IsType(t, []string{}, store.Get(fmt.Sprintf("iamListAllGroupPolicies_group_%s", *group.GroupName))) + } + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/kms_repository.go b/enumeration/remote/aws/repository/kms_repository.go new file mode 100644 index 00000000..347ad5ca --- /dev/null +++ b/enumeration/remote/aws/repository/kms_repository.go @@ -0,0 +1,146 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go/service/kms/kmsiface" + "github.com/sirupsen/logrus" +) + +type KMSRepository interface { + ListAllKeys() ([]*kms.KeyListEntry, error) + ListAllAliases() ([]*kms.AliasListEntry, error) +} + +type kmsRepository struct { + client kmsiface.KMSAPI + cache cache.Cache + describeKeyLock *sync.Mutex +} + +func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository { + return &kmsRepository{ + kms.New(session), + c, + &sync.Mutex{}, + } +} + +func (r *kmsRepository) ListAllKeys() ([]*kms.KeyListEntry, error) { + if v := r.cache.Get("kmsListAllKeys"); v != nil { + return v.([]*kms.KeyListEntry), nil + } + + var keys []*kms.KeyListEntry + input := kms.ListKeysInput{} + err := r.client.ListKeysPages(&input, + func(resp *kms.ListKeysOutput, lastPage bool) bool { + keys = append(keys, resp.Keys...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + customerKeys, err := r.filterKeys(keys) + if err != nil { + return nil, err + } + + r.cache.Put("kmsListAllKeys", customerKeys) + return customerKeys, nil +} + +func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) { + if v := r.cache.Get("kmsListAllAliases"); v != nil { + return v.([]*kms.AliasListEntry), nil + } + + var aliases []*kms.AliasListEntry + input := kms.ListAliasesInput{} + err := r.client.ListAliasesPages(&input, + func(resp *kms.ListAliasesOutput, lastPage bool) bool { + aliases = append(aliases, resp.Aliases...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + result, err := r.filterAliases(aliases) + if err != nil { + return nil, err + } + r.cache.Put("kmsListAllAliases", result) + return result, nil +} + +func (r *kmsRepository) describeKey(keyId *string) (*kms.DescribeKeyOutput, error) { + var results interface{} + // Since this method can be call in parallel, we should lock and unlock if we want to be sure to hit the cache + r.describeKeyLock.Lock() + defer r.describeKeyLock.Unlock() + cacheKey := fmt.Sprintf("kmsDescribeKey-%s", *keyId) + results = r.cache.Get(cacheKey) + if results == nil { + var err error + results, err = r.client.DescribeKey(&kms.DescribeKeyInput{KeyId: keyId}) + if err != nil { + return nil, err + } + r.cache.Put(cacheKey, results) + } + describeKey := results.(*kms.DescribeKeyOutput) + if aws.StringValue(describeKey.KeyMetadata.KeyState) == kms.KeyStatePendingDeletion { + return nil, nil + } + return describeKey, nil +} + +func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) { + var customerKeys []*kms.KeyListEntry + for _, key := range keys { + k, err := r.describeKey(key.KeyId) + if err != nil { + return nil, err + } + if k == nil { + logrus.WithFields(logrus.Fields{ + "id": *key.KeyId, + }).Debug("Ignored kms key from listing since it is pending from deletion") + continue + } + if k.KeyMetadata.KeyManager != nil && *k.KeyMetadata.KeyManager != "AWS" { + customerKeys = append(customerKeys, key) + } + } + return customerKeys, nil +} + +func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) ([]*kms.AliasListEntry, error) { + var customerAliases []*kms.AliasListEntry + for _, alias := range aliases { + if alias.AliasName != nil && !strings.HasPrefix(*alias.AliasName, "alias/aws/") { + k, err := r.describeKey(alias.TargetKeyId) + if err != nil { + return nil, err + } + if k == nil { + logrus.WithFields(logrus.Fields{ + "id": *alias.TargetKeyId, + "alias": *alias.AliasName, + }).Debug("Ignored kms key alias from listing since it is linked to a pending from deletion key") + continue + } + customerAliases = append(customerAliases, alias) + } + } + return customerAliases, nil +} diff --git a/enumeration/remote/aws/repository/kms_repository_test.go b/enumeration/remote/aws/repository/kms_repository_test.go new file mode 100644 index 00000000..30ea14ec --- /dev/null +++ b/enumeration/remote/aws/repository/kms_repository_test.go @@ -0,0 +1,249 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "sync" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/kms" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_KMSRepository_ListAllKeys(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeKMS) + want []*kms.KeyListEntry + wantErr error + }{ + { + name: "List only enabled keys", + mocks: func(client *awstest.MockFakeKMS) { + client.On("ListKeysPages", + &kms.ListKeysInput{}, + mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool { + callback(&kms.ListKeysOutput{ + Keys: []*kms.KeyListEntry{ + {KeyId: aws.String("1")}, + {KeyId: aws.String("2")}, + }, + }, true) + return true + })).Return(nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("1"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("1"), + KeyManager: aws.String("CUSTOMER"), + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("2"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("2"), + KeyManager: aws.String("CUSTOMER"), + KeyState: aws.String(kms.KeyStatePendingDeletion), + }, + }, nil).Once() + }, + want: []*kms.KeyListEntry{ + {KeyId: aws.String("1")}, + }, + }, + { + name: "List only customer keys", + mocks: func(client *awstest.MockFakeKMS) { + client.On("ListKeysPages", + &kms.ListKeysInput{}, + mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool { + callback(&kms.ListKeysOutput{ + Keys: []*kms.KeyListEntry{ + {KeyId: aws.String("1")}, + {KeyId: aws.String("2")}, + {KeyId: aws.String("3")}, + }, + }, true) + return true + })).Return(nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("1"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("1"), + KeyManager: aws.String("CUSTOMER"), + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("2"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("2"), + KeyManager: aws.String("AWS"), + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil).Once() + client.On("DescribeKey", + &kms.DescribeKeyInput{ + KeyId: aws.String("3"), + }).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyId: aws.String("3"), + KeyManager: aws.String("AWS"), + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil).Once() + }, + want: []*kms.KeyListEntry{ + {KeyId: aws.String("1")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeKMS{} + tt.mocks(&client) + r := &kmsRepository{ + client: &client, + cache: store, + describeKeyLock: &sync.Mutex{}, + } + got, err := r.ListAllKeys() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllKeys() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*kms.KeyListEntry{}, store.Get("kmsListAllKeys")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_KMSRepository_ListAllAliases(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeKMS) + want []*kms.AliasListEntry + wantErr error + }{ + { + name: "List only aliases for enabled keys", + mocks: func(client *awstest.MockFakeKMS) { + client.On("ListAliasesPages", + &kms.ListAliasesInput{}, + mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool { + callback(&kms.ListAliasesOutput{ + Aliases: []*kms.AliasListEntry{ + {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, + {AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")}, + }, + }, true) + return true + })).Return(nil).Once() + client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-1")}).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyState: aws.String(kms.KeyStatePendingDeletion), + }, + }, nil) + client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-2")}).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil) + }, + want: []*kms.AliasListEntry{ + {AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")}, + }, + }, + { + name: "List only customer aliases", + mocks: func(client *awstest.MockFakeKMS) { + client.On("ListAliasesPages", + &kms.ListAliasesInput{}, + mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool { + callback(&kms.ListAliasesOutput{ + Aliases: []*kms.AliasListEntry{ + {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, + {AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")}, + {AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")}, + {AliasName: aws.String("alias/aws/4"), TargetKeyId: aws.String("key-id-4")}, + {AliasName: aws.String("alias/aws/5"), TargetKeyId: aws.String("key-id-5")}, + {AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")}, + {AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")}, + }, + }, true) + return true + })).Return(nil).Once() + client.On("DescribeKey", mock.Anything).Return(&kms.DescribeKeyOutput{ + KeyMetadata: &kms.KeyMetadata{ + KeyState: aws.String(kms.KeyStateEnabled), + }, + }, nil) + }, + want: []*kms.AliasListEntry{ + {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, + {AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")}, + {AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")}, + {AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")}, + {AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeKMS{} + tt.mocks(&client) + r := &kmsRepository{ + client: &client, + cache: store, + describeKeyLock: &sync.Mutex{}, + } + got, err := r.ListAllAliases() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllAliases() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*kms.AliasListEntry{}, store.Get("kmsListAllAliases")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/lambda_repository.go b/enumeration/remote/aws/repository/lambda_repository.go new file mode 100644 index 00000000..d9377d26 --- /dev/null +++ b/enumeration/remote/aws/repository/lambda_repository.go @@ -0,0 +1,63 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/lambda" + "github.com/aws/aws-sdk-go/service/lambda/lambdaiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type LambdaRepository interface { + ListAllLambdaFunctions() ([]*lambda.FunctionConfiguration, error) + ListAllLambdaEventSourceMappings() ([]*lambda.EventSourceMappingConfiguration, error) +} + +type lambdaRepository struct { + client lambdaiface.LambdaAPI + cache cache.Cache +} + +func NewLambdaRepository(session *session.Session, c cache.Cache) *lambdaRepository { + return &lambdaRepository{ + lambda.New(session), + c, + } +} + +func (r *lambdaRepository) ListAllLambdaFunctions() ([]*lambda.FunctionConfiguration, error) { + if v := r.cache.Get("lambdaListAllLambdaFunctions"); v != nil { + return v.([]*lambda.FunctionConfiguration), nil + } + + var functions []*lambda.FunctionConfiguration + input := &lambda.ListFunctionsInput{} + err := r.client.ListFunctionsPages(input, func(res *lambda.ListFunctionsOutput, lastPage bool) bool { + functions = append(functions, res.Functions...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("lambdaListAllLambdaFunctions", functions) + return functions, nil +} + +func (r *lambdaRepository) ListAllLambdaEventSourceMappings() ([]*lambda.EventSourceMappingConfiguration, error) { + if v := r.cache.Get("lambdaListAllLambdaEventSourceMappings"); v != nil { + return v.([]*lambda.EventSourceMappingConfiguration), nil + } + + var eventSourceMappingConfigurations []*lambda.EventSourceMappingConfiguration + input := &lambda.ListEventSourceMappingsInput{} + err := r.client.ListEventSourceMappingsPages(input, func(res *lambda.ListEventSourceMappingsOutput, lastPage bool) bool { + eventSourceMappingConfigurations = append(eventSourceMappingConfigurations, res.EventSourceMappings...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("lambdaListAllLambdaEventSourceMappings", eventSourceMappingConfigurations) + return eventSourceMappingConfigurations, nil +} diff --git a/enumeration/remote/aws/repository/lambda_repository_test.go b/enumeration/remote/aws/repository/lambda_repository_test.go new file mode 100644 index 00000000..4981884e --- /dev/null +++ b/enumeration/remote/aws/repository/lambda_repository_test.go @@ -0,0 +1,169 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/aws/aws-sdk-go/service/lambda" + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_lambdaRepository_ListAllLambdaFunctions(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeLambda) + want []*lambda.FunctionConfiguration + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeLambda) { + client.On("ListFunctionsPages", + &lambda.ListFunctionsInput{}, + mock.MatchedBy(func(callback func(res *lambda.ListFunctionsOutput, lastPage bool) bool) bool { + callback(&lambda.ListFunctionsOutput{ + Functions: []*lambda.FunctionConfiguration{ + {FunctionName: aws.String("1")}, + {FunctionName: aws.String("2")}, + {FunctionName: aws.String("3")}, + {FunctionName: aws.String("4")}, + }, + }, false) + callback(&lambda.ListFunctionsOutput{ + Functions: []*lambda.FunctionConfiguration{ + {FunctionName: aws.String("5")}, + {FunctionName: aws.String("6")}, + {FunctionName: aws.String("7")}, + {FunctionName: aws.String("8")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*lambda.FunctionConfiguration{ + {FunctionName: aws.String("1")}, + {FunctionName: aws.String("2")}, + {FunctionName: aws.String("3")}, + {FunctionName: aws.String("4")}, + {FunctionName: aws.String("5")}, + {FunctionName: aws.String("6")}, + {FunctionName: aws.String("7")}, + {FunctionName: aws.String("8")}, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeLambda{} + tt.mocks(client) + r := &lambdaRepository{ + client: client, + cache: store, + } + got, err := r.ListAllLambdaFunctions() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllLambdaFunctions() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*lambda.FunctionConfiguration{}, store.Get("lambdaListAllLambdaFunctions")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_lambdaRepository_ListAllLambdaEventSourceMappings(t *testing.T) { + tests := []struct { + name string + mocks func(mock *awstest.MockFakeLambda) + want []*lambda.EventSourceMappingConfiguration + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeLambda) { + client.On("ListEventSourceMappingsPages", + &lambda.ListEventSourceMappingsInput{}, + mock.MatchedBy(func(callback func(res *lambda.ListEventSourceMappingsOutput, lastPage bool) bool) bool { + callback(&lambda.ListEventSourceMappingsOutput{ + EventSourceMappings: []*lambda.EventSourceMappingConfiguration{ + {UUID: aws.String("1")}, + {UUID: aws.String("2")}, + {UUID: aws.String("3")}, + {UUID: aws.String("4")}, + }, + }, false) + callback(&lambda.ListEventSourceMappingsOutput{ + EventSourceMappings: []*lambda.EventSourceMappingConfiguration{ + {UUID: aws.String("5")}, + {UUID: aws.String("6")}, + {UUID: aws.String("7")}, + {UUID: aws.String("8")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*lambda.EventSourceMappingConfiguration{ + {UUID: aws.String("1")}, + {UUID: aws.String("2")}, + {UUID: aws.String("3")}, + {UUID: aws.String("4")}, + {UUID: aws.String("5")}, + {UUID: aws.String("6")}, + {UUID: aws.String("7")}, + {UUID: aws.String("8")}, + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeLambda{} + tt.mocks(client) + r := &lambdaRepository{ + client: client, + cache: store, + } + got, err := r.ListAllLambdaEventSourceMappings() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllLambdaEventSourceMappings() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*lambda.EventSourceMappingConfiguration{}, store.Get("lambdaListAllLambdaEventSourceMappings")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/pkg/remote/aws/repository/mock_ApiGatewayRepository.go b/enumeration/remote/aws/repository/mock_ApiGatewayRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_ApiGatewayRepository.go rename to enumeration/remote/aws/repository/mock_ApiGatewayRepository.go diff --git a/pkg/remote/aws/repository/mock_ApiGatewayV2Repository.go b/enumeration/remote/aws/repository/mock_ApiGatewayV2Repository.go similarity index 100% rename from pkg/remote/aws/repository/mock_ApiGatewayV2Repository.go rename to enumeration/remote/aws/repository/mock_ApiGatewayV2Repository.go diff --git a/pkg/remote/aws/repository/mock_AppAutoScalingRepository.go b/enumeration/remote/aws/repository/mock_AppAutoScalingRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_AppAutoScalingRepository.go rename to enumeration/remote/aws/repository/mock_AppAutoScalingRepository.go diff --git a/pkg/remote/aws/repository/mock_AutoScalingRepository.go b/enumeration/remote/aws/repository/mock_AutoScalingRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_AutoScalingRepository.go rename to enumeration/remote/aws/repository/mock_AutoScalingRepository.go diff --git a/pkg/remote/aws/repository/mock_CloudformationRepository.go b/enumeration/remote/aws/repository/mock_CloudformationRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_CloudformationRepository.go rename to enumeration/remote/aws/repository/mock_CloudformationRepository.go diff --git a/pkg/remote/aws/repository/mock_CloudfrontRepository.go b/enumeration/remote/aws/repository/mock_CloudfrontRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_CloudfrontRepository.go rename to enumeration/remote/aws/repository/mock_CloudfrontRepository.go diff --git a/pkg/remote/aws/repository/mock_DynamoDBRepository.go b/enumeration/remote/aws/repository/mock_DynamoDBRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_DynamoDBRepository.go rename to enumeration/remote/aws/repository/mock_DynamoDBRepository.go diff --git a/pkg/remote/aws/repository/mock_EC2Repository.go b/enumeration/remote/aws/repository/mock_EC2Repository.go similarity index 100% rename from pkg/remote/aws/repository/mock_EC2Repository.go rename to enumeration/remote/aws/repository/mock_EC2Repository.go diff --git a/pkg/remote/aws/repository/mock_ECRRepository.go b/enumeration/remote/aws/repository/mock_ECRRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_ECRRepository.go rename to enumeration/remote/aws/repository/mock_ECRRepository.go diff --git a/pkg/remote/aws/repository/mock_ELBRepository.go b/enumeration/remote/aws/repository/mock_ELBRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_ELBRepository.go rename to enumeration/remote/aws/repository/mock_ELBRepository.go diff --git a/pkg/remote/aws/repository/mock_ELBV2Repository.go b/enumeration/remote/aws/repository/mock_ELBV2Repository.go similarity index 100% rename from pkg/remote/aws/repository/mock_ELBV2Repository.go rename to enumeration/remote/aws/repository/mock_ELBV2Repository.go diff --git a/pkg/remote/aws/repository/mock_ElastiCacheRepository.go b/enumeration/remote/aws/repository/mock_ElastiCacheRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_ElastiCacheRepository.go rename to enumeration/remote/aws/repository/mock_ElastiCacheRepository.go diff --git a/pkg/remote/aws/repository/mock_IAMRepository.go b/enumeration/remote/aws/repository/mock_IAMRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_IAMRepository.go rename to enumeration/remote/aws/repository/mock_IAMRepository.go diff --git a/pkg/remote/aws/repository/mock_KMSRepository.go b/enumeration/remote/aws/repository/mock_KMSRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_KMSRepository.go rename to enumeration/remote/aws/repository/mock_KMSRepository.go diff --git a/pkg/remote/aws/repository/mock_LambdaRepository.go b/enumeration/remote/aws/repository/mock_LambdaRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_LambdaRepository.go rename to enumeration/remote/aws/repository/mock_LambdaRepository.go diff --git a/pkg/remote/aws/repository/mock_RDSRepository.go b/enumeration/remote/aws/repository/mock_RDSRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_RDSRepository.go rename to enumeration/remote/aws/repository/mock_RDSRepository.go diff --git a/pkg/remote/aws/repository/mock_Route53Repository.go b/enumeration/remote/aws/repository/mock_Route53Repository.go similarity index 100% rename from pkg/remote/aws/repository/mock_Route53Repository.go rename to enumeration/remote/aws/repository/mock_Route53Repository.go diff --git a/pkg/remote/aws/repository/mock_S3Repository.go b/enumeration/remote/aws/repository/mock_S3Repository.go similarity index 100% rename from pkg/remote/aws/repository/mock_S3Repository.go rename to enumeration/remote/aws/repository/mock_S3Repository.go diff --git a/pkg/remote/aws/repository/mock_SNSRepository.go b/enumeration/remote/aws/repository/mock_SNSRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_SNSRepository.go rename to enumeration/remote/aws/repository/mock_SNSRepository.go diff --git a/pkg/remote/aws/repository/mock_SQSRepository.go b/enumeration/remote/aws/repository/mock_SQSRepository.go similarity index 100% rename from pkg/remote/aws/repository/mock_SQSRepository.go rename to enumeration/remote/aws/repository/mock_SQSRepository.go diff --git a/enumeration/remote/aws/repository/rds_repository.go b/enumeration/remote/aws/repository/rds_repository.go new file mode 100644 index 00000000..95d2ebfb --- /dev/null +++ b/enumeration/remote/aws/repository/rds_repository.go @@ -0,0 +1,82 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type RDSRepository interface { + ListAllDBInstances() ([]*rds.DBInstance, error) + ListAllDBSubnetGroups() ([]*rds.DBSubnetGroup, error) + ListAllDBClusters() ([]*rds.DBCluster, error) +} + +type rdsRepository struct { + client rdsiface.RDSAPI + cache cache.Cache +} + +func NewRDSRepository(session *session.Session, c cache.Cache) *rdsRepository { + return &rdsRepository{ + rds.New(session), + c, + } +} + +func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) { + if v := r.cache.Get("rdsListAllDBInstances"); v != nil { + return v.([]*rds.DBInstance), nil + } + + var result []*rds.DBInstance + input := &rds.DescribeDBInstancesInput{} + err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool { + result = append(result, res.DBInstances...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("rdsListAllDBInstances", result) + return result, nil +} + +func (r *rdsRepository) ListAllDBSubnetGroups() ([]*rds.DBSubnetGroup, error) { + if v := r.cache.Get("rdsListAllDBSubnetGroups"); v != nil { + return v.([]*rds.DBSubnetGroup), nil + } + + var subnetGroups []*rds.DBSubnetGroup + input := rds.DescribeDBSubnetGroupsInput{} + err := r.client.DescribeDBSubnetGroupsPages(&input, + func(resp *rds.DescribeDBSubnetGroupsOutput, lastPage bool) bool { + subnetGroups = append(subnetGroups, resp.DBSubnetGroups...) + return !lastPage + }, + ) + + r.cache.Put("rdsListAllDBSubnetGroups", subnetGroups) + return subnetGroups, err +} + +func (r *rdsRepository) ListAllDBClusters() ([]*rds.DBCluster, error) { + cacheKey := "rdsListAllDBClusters" + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*rds.DBCluster), nil + } + + var clusters []*rds.DBCluster + input := rds.DescribeDBClustersInput{} + err := r.client.DescribeDBClustersPages(&input, + func(resp *rds.DescribeDBClustersOutput, lastPage bool) bool { + clusters = append(clusters, resp.DBClusters...) + return !lastPage + }, + ) + + r.cache.Put(cacheKey, clusters) + return clusters, err +} diff --git a/enumeration/remote/aws/repository/rds_repository_test.go b/enumeration/remote/aws/repository/rds_repository_test.go new file mode 100644 index 00000000..639ffb8c --- /dev/null +++ b/enumeration/remote/aws/repository/rds_repository_test.go @@ -0,0 +1,245 @@ +package repository + +import ( + cache2 "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/rds" + "github.com/r3labs/diff/v2" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_rdsRepository_ListAllDBInstances(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeRDS) + want []*rds.DBInstance + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeRDS) { + client.On("DescribeDBInstancesPages", + &rds.DescribeDBInstancesInput{}, + mock.MatchedBy(func(callback func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool) bool { + callback(&rds.DescribeDBInstancesOutput{ + DBInstances: []*rds.DBInstance{ + {DBInstanceIdentifier: aws.String("1")}, + {DBInstanceIdentifier: aws.String("2")}, + {DBInstanceIdentifier: aws.String("3")}, + }, + }, false) + callback(&rds.DescribeDBInstancesOutput{ + DBInstances: []*rds.DBInstance{ + {DBInstanceIdentifier: aws.String("4")}, + {DBInstanceIdentifier: aws.String("5")}, + {DBInstanceIdentifier: aws.String("6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*rds.DBInstance{ + {DBInstanceIdentifier: aws.String("1")}, + {DBInstanceIdentifier: aws.String("2")}, + {DBInstanceIdentifier: aws.String("3")}, + {DBInstanceIdentifier: aws.String("4")}, + {DBInstanceIdentifier: aws.String("5")}, + {DBInstanceIdentifier: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeRDS{} + tt.mocks(client) + r := &rdsRepository{ + client: client, + cache: store, + } + got, err := r.ListAllDBInstances() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllDBInstances() + assert.Nil(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*rds.DBInstance{}, store.Get("rdsListAllDBInstances")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_rdsRepository_ListAllDBSubnetGroups(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeRDS) + want []*rds.DBSubnetGroup + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeRDS) { + client.On("DescribeDBSubnetGroupsPages", + &rds.DescribeDBSubnetGroupsInput{}, + mock.MatchedBy(func(callback func(res *rds.DescribeDBSubnetGroupsOutput, lastPage bool) bool) bool { + callback(&rds.DescribeDBSubnetGroupsOutput{ + DBSubnetGroups: []*rds.DBSubnetGroup{ + {DBSubnetGroupName: aws.String("1")}, + {DBSubnetGroupName: aws.String("2")}, + {DBSubnetGroupName: aws.String("3")}, + }, + }, false) + callback(&rds.DescribeDBSubnetGroupsOutput{ + DBSubnetGroups: []*rds.DBSubnetGroup{ + {DBSubnetGroupName: aws.String("4")}, + {DBSubnetGroupName: aws.String("5")}, + {DBSubnetGroupName: aws.String("6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*rds.DBSubnetGroup{ + {DBSubnetGroupName: aws.String("1")}, + {DBSubnetGroupName: aws.String("2")}, + {DBSubnetGroupName: aws.String("3")}, + {DBSubnetGroupName: aws.String("4")}, + {DBSubnetGroupName: aws.String("5")}, + {DBSubnetGroupName: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache2.New(1) + client := &awstest.MockFakeRDS{} + tt.mocks(client) + r := &rdsRepository{ + client: client, + cache: store, + } + got, err := r.ListAllDBSubnetGroups() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllDBSubnetGroups() + assert.Nil(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*rds.DBSubnetGroup{}, store.Get("rdsListAllDBSubnetGroups")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_rdsRepository_ListAllDBClusters(t *testing.T) { + tests := []struct { + name string + mocks func(*awstest.MockFakeRDS, *cache2.MockCache) + want []*rds.DBCluster + wantErr error + }{ + { + name: "should list with 2 pages", + mocks: func(client *awstest.MockFakeRDS, store *cache2.MockCache) { + clusters := []*rds.DBCluster{ + {DBClusterIdentifier: aws.String("1")}, + {DBClusterIdentifier: aws.String("2")}, + {DBClusterIdentifier: aws.String("3")}, + {DBClusterIdentifier: aws.String("4")}, + {DBClusterIdentifier: aws.String("5")}, + {DBClusterIdentifier: aws.String("6")}, + } + + client.On("DescribeDBClustersPages", + &rds.DescribeDBClustersInput{}, + mock.MatchedBy(func(callback func(res *rds.DescribeDBClustersOutput, lastPage bool) bool) bool { + callback(&rds.DescribeDBClustersOutput{DBClusters: clusters[:3]}, false) + callback(&rds.DescribeDBClustersOutput{DBClusters: clusters[3:]}, true) + return true + })).Return(nil).Once() + + store.On("Get", "rdsListAllDBClusters").Return(nil).Once() + store.On("Put", "rdsListAllDBClusters", clusters).Return(false).Once() + }, + want: []*rds.DBCluster{ + {DBClusterIdentifier: aws.String("1")}, + {DBClusterIdentifier: aws.String("2")}, + {DBClusterIdentifier: aws.String("3")}, + {DBClusterIdentifier: aws.String("4")}, + {DBClusterIdentifier: aws.String("5")}, + {DBClusterIdentifier: aws.String("6")}, + }, + }, + { + name: "should hit cache", + mocks: func(client *awstest.MockFakeRDS, store *cache2.MockCache) { + clusters := []*rds.DBCluster{ + {DBClusterIdentifier: aws.String("1")}, + {DBClusterIdentifier: aws.String("2")}, + {DBClusterIdentifier: aws.String("3")}, + {DBClusterIdentifier: aws.String("4")}, + {DBClusterIdentifier: aws.String("5")}, + {DBClusterIdentifier: aws.String("6")}, + } + + store.On("Get", "rdsListAllDBClusters").Return(clusters).Once() + }, + want: []*rds.DBCluster{ + {DBClusterIdentifier: aws.String("1")}, + {DBClusterIdentifier: aws.String("2")}, + {DBClusterIdentifier: aws.String("3")}, + {DBClusterIdentifier: aws.String("4")}, + {DBClusterIdentifier: aws.String("5")}, + {DBClusterIdentifier: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := &cache2.MockCache{} + client := &awstest.MockFakeRDS{} + tt.mocks(client, store) + r := &rdsRepository{ + client: client, + cache: store, + } + got, err := r.ListAllDBClusters() + assert.Equal(t, tt.wantErr, err) + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/route53_repository.go b/enumeration/remote/aws/repository/route53_repository.go new file mode 100644 index 00000000..d1e7a51c --- /dev/null +++ b/enumeration/remote/aws/repository/route53_repository.go @@ -0,0 +1,92 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/route53" + "github.com/aws/aws-sdk-go/service/route53/route53iface" +) + +type Route53Repository interface { + ListAllHealthChecks() ([]*route53.HealthCheck, error) + ListAllZones() ([]*route53.HostedZone, error) + ListRecordsForZone(zoneId string) ([]*route53.ResourceRecordSet, error) +} + +type route53Repository struct { + client route53iface.Route53API + cache cache.Cache +} + +func NewRoute53Repository(session *session.Session, c cache.Cache) *route53Repository { + return &route53Repository{ + route53.New(session), + c, + } +} + +func (r *route53Repository) ListAllHealthChecks() ([]*route53.HealthCheck, error) { + if v := r.cache.Get("route53ListAllHealthChecks"); v != nil { + return v.([]*route53.HealthCheck), nil + } + + var tables []*route53.HealthCheck + input := &route53.ListHealthChecksInput{} + err := r.client.ListHealthChecksPages(input, func(res *route53.ListHealthChecksOutput, lastPage bool) bool { + tables = append(tables, res.HealthChecks...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("route53ListAllHealthChecks", tables) + return tables, nil +} + +func (r *route53Repository) ListAllZones() ([]*route53.HostedZone, error) { + cacheKey := "route53ListAllZones" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*route53.HostedZone), nil + } + + var result []*route53.HostedZone + input := &route53.ListHostedZonesInput{} + err := r.client.ListHostedZonesPages(input, func(res *route53.ListHostedZonesOutput, lastPage bool) bool { + result = append(result, res.HostedZones...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, result) + return result, nil +} + +func (r *route53Repository) ListRecordsForZone(zoneId string) ([]*route53.ResourceRecordSet, error) { + cacheKey := fmt.Sprintf("route53ListRecordsForZone_%s", zoneId) + if v := r.cache.Get(cacheKey); v != nil { + return v.([]*route53.ResourceRecordSet), nil + } + + var results []*route53.ResourceRecordSet + input := &route53.ListResourceRecordSetsInput{ + HostedZoneId: aws.String(zoneId), + } + err := r.client.ListResourceRecordSetsPages(input, func(res *route53.ListResourceRecordSetsOutput, lastPage bool) bool { + results = append(results, res.ResourceRecordSets...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, results) + return results, nil +} diff --git a/enumeration/remote/aws/repository/route53_repository_test.go b/enumeration/remote/aws/repository/route53_repository_test.go new file mode 100644 index 00000000..77a23c23 --- /dev/null +++ b/enumeration/remote/aws/repository/route53_repository_test.go @@ -0,0 +1,240 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/route53" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" +) + +func Test_route53Repository_ListAllHealthChecks(t *testing.T) { + + tests := []struct { + name string + mocks func(client *awstest.MockFakeRoute53) + want []*route53.HealthCheck + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeRoute53) { + client.On("ListHealthChecksPages", + &route53.ListHealthChecksInput{}, + mock.MatchedBy(func(callback func(res *route53.ListHealthChecksOutput, lastPage bool) bool) bool { + callback(&route53.ListHealthChecksOutput{ + HealthChecks: []*route53.HealthCheck{ + {Id: aws.String("1")}, + {Id: aws.String("2")}, + {Id: aws.String("3")}, + }, + }, false) + callback(&route53.ListHealthChecksOutput{ + HealthChecks: []*route53.HealthCheck{ + {Id: aws.String("4")}, + {Id: aws.String("5")}, + {Id: aws.String("6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*route53.HealthCheck{ + {Id: aws.String("1")}, + {Id: aws.String("2")}, + {Id: aws.String("3")}, + {Id: aws.String("4")}, + {Id: aws.String("5")}, + {Id: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeRoute53{} + tt.mocks(&client) + r := &route53Repository{ + client: &client, + cache: store, + } + got, err := r.ListAllHealthChecks() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllHealthChecks() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*route53.HealthCheck{}, store.Get("route53ListAllHealthChecks")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_route53Repository_ListAllZones(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeRoute53) + want []*route53.HostedZone + wantErr error + }{ + {name: "Zones with 2 pages", + mocks: func(client *awstest.MockFakeRoute53) { + client.On("ListHostedZonesPages", + &route53.ListHostedZonesInput{}, + mock.MatchedBy(func(callback func(res *route53.ListHostedZonesOutput, lastPage bool) bool) bool { + callback(&route53.ListHostedZonesOutput{ + HostedZones: []*route53.HostedZone{ + {Id: aws.String("1")}, + {Id: aws.String("2")}, + {Id: aws.String("3")}, + }, + }, false) + callback(&route53.ListHostedZonesOutput{ + HostedZones: []*route53.HostedZone{ + {Id: aws.String("4")}, + {Id: aws.String("5")}, + {Id: aws.String("6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*route53.HostedZone{ + {Id: aws.String("1")}, + {Id: aws.String("2")}, + {Id: aws.String("3")}, + {Id: aws.String("4")}, + {Id: aws.String("5")}, + {Id: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeRoute53{} + tt.mocks(&client) + r := &route53Repository{ + client: &client, + cache: store, + } + got, err := r.ListAllZones() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllZones() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*route53.HostedZone{}, store.Get("route53ListAllZones")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_route53Repository_ListRecordsForZone(t *testing.T) { + tests := []struct { + name string + zoneIds []string + mocks func(client *awstest.MockFakeRoute53) + want []*route53.ResourceRecordSet + wantErr error + }{ + { + name: "records for zone with 2 pages", + zoneIds: []string{ + "1", + }, + mocks: func(client *awstest.MockFakeRoute53) { + client.On("ListResourceRecordSetsPages", + &route53.ListResourceRecordSetsInput{ + HostedZoneId: aws.String("1"), + }, + mock.MatchedBy(func(callback func(res *route53.ListResourceRecordSetsOutput, lastPage bool) bool) bool { + callback(&route53.ListResourceRecordSetsOutput{ + ResourceRecordSets: []*route53.ResourceRecordSet{ + {Name: aws.String("1")}, + {Name: aws.String("2")}, + {Name: aws.String("3")}, + }, + }, false) + callback(&route53.ListResourceRecordSetsOutput{ + ResourceRecordSets: []*route53.ResourceRecordSet{ + {Name: aws.String("4")}, + {Name: aws.String("5")}, + {Name: aws.String("6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*route53.ResourceRecordSet{ + {Name: aws.String("1")}, + {Name: aws.String("2")}, + {Name: aws.String("3")}, + {Name: aws.String("4")}, + {Name: aws.String("5")}, + {Name: aws.String("6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := awstest.MockFakeRoute53{} + tt.mocks(&client) + r := &route53Repository{ + client: &client, + cache: store, + } + for _, id := range tt.zoneIds { + got, err := r.ListRecordsForZone(id) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListRecordsForZone(id) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*route53.ResourceRecordSet{}, store.Get(fmt.Sprintf("route53ListRecordsForZone_%s", id))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + } + }) + } +} diff --git a/enumeration/remote/aws/repository/s3_repository.go b/enumeration/remote/aws/repository/s3_repository.go new file mode 100644 index 00000000..2b5f5557 --- /dev/null +++ b/enumeration/remote/aws/repository/s3_repository.go @@ -0,0 +1,287 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/aws/client" + "github.com/snyk/driftctl/enumeration/remote/cache" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type S3Repository interface { + ListAllBuckets() ([]*s3.Bucket, error) + GetBucketNotification(bucketName, region string) (*s3.NotificationConfiguration, error) + GetBucketPolicy(bucketName, region string) (*string, error) + GetBucketPublicAccessBlock(bucketName, region string) (*s3.PublicAccessBlockConfiguration, error) + ListBucketInventoryConfigurations(bucket *s3.Bucket, region string) ([]*s3.InventoryConfiguration, error) + ListBucketMetricsConfigurations(bucket *s3.Bucket, region string) ([]*s3.MetricsConfiguration, error) + ListBucketAnalyticsConfigurations(bucket *s3.Bucket, region string) ([]*s3.AnalyticsConfiguration, error) + GetBucketLocation(bucketName string) (string, error) +} + +type s3Repository struct { + clientFactory client.AwsClientFactoryInterface + cache cache.Cache +} + +func NewS3Repository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3Repository { + return &s3Repository{ + factory, + c, + } +} + +func (s *s3Repository) ListAllBuckets() ([]*s3.Bucket, error) { + cacheKey := "s3ListAllBuckets" + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.([]*s3.Bucket), nil + } + + out, err := s.clientFactory.GetS3Client(nil).ListBuckets(&s3.ListBucketsInput{}) + if err != nil { + return nil, err + } + s.cache.Put(cacheKey, out.Buckets) + return out.Buckets, nil +} + +func (s *s3Repository) GetBucketPolicy(bucketName, region string) (*string, error) { + cacheKey := fmt.Sprintf("s3GetBucketPolicy_%s_%s", bucketName, region) + if v := s.cache.Get(cacheKey); v != nil { + return v.(*string), nil + } + policy, err := s.clientFactory. + GetS3Client(&awssdk.Config{Region: ®ion}). + GetBucketPolicy( + &s3.GetBucketPolicyInput{Bucket: &bucketName}, + ) + if err != nil { + if awsErr, ok := err.(awserr.Error); ok { + if awsErr.Code() == "NoSuchBucketPolicy" { + return nil, nil + } + } + return nil, errors.Wrapf( + err, + "Error listing bucket policy %s", + bucketName, + ) + } + + result := policy.Policy + if result != nil && *result == "" { + result = nil + } + + s.cache.Put(cacheKey, result) + return result, nil +} + +func (s *s3Repository) GetBucketPublicAccessBlock(bucketName, region string) (*s3.PublicAccessBlockConfiguration, error) { + cacheKey := fmt.Sprintf("s3GetBucketPublicAccessBlock_%s_%s", bucketName, region) + if v := s.cache.Get(cacheKey); v != nil { + return v.(*s3.PublicAccessBlockConfiguration), nil + } + response, err := s.clientFactory. + GetS3Client(&awssdk.Config{Region: ®ion}). + GetPublicAccessBlock(&s3.GetPublicAccessBlockInput{Bucket: &bucketName}) + + if err != nil { + if awsErr, ok := err.(awserr.Error); ok { + if awsErr.Code() == "NoSuchPublicAccessBlockConfiguration" { + return nil, nil + } + } + return nil, errors.Wrapf( + err, + "Error listing bucket public access block %s", + bucketName, + ) + } + + result := response.PublicAccessBlockConfiguration + + s.cache.Put(cacheKey, result) + return result, nil +} + +func (s *s3Repository) GetBucketNotification(bucketName, region string) (*s3.NotificationConfiguration, error) { + cacheKey := fmt.Sprintf("s3GetBucketNotification_%s_%s", bucketName, region) + if v := s.cache.Get(cacheKey); v != nil { + return v.(*s3.NotificationConfiguration), nil + } + bucketNotificationConfig, err := s.clientFactory. + GetS3Client(&awssdk.Config{Region: ®ion}). + GetBucketNotificationConfiguration( + &s3.GetBucketNotificationConfigurationRequest{Bucket: &bucketName}, + ) + if err != nil { + return nil, errors.Wrapf( + err, + "Error listing bucket notification configuration %s", + bucketName, + ) + } + + result := bucketNotificationConfig + if s.notificationIsEmpty(bucketNotificationConfig) { + result = nil + } + + s.cache.Put(cacheKey, result) + return result, nil +} + +func (s *s3Repository) notificationIsEmpty(notification *s3.NotificationConfiguration) bool { + return notification.TopicConfigurations == nil && + notification.QueueConfigurations == nil && + notification.LambdaFunctionConfigurations == nil +} + +func (s *s3Repository) ListBucketInventoryConfigurations(bucket *s3.Bucket, region string) ([]*s3.InventoryConfiguration, error) { + cacheKey := fmt.Sprintf("s3ListBucketInventoryConfigurations_%s_%s", *bucket.Name, region) + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*s3.InventoryConfiguration), nil + } + + inventoryConfigurations := make([]*s3.InventoryConfiguration, 0) + s3client := s.clientFactory.GetS3Client(&awssdk.Config{Region: ®ion}) + request := &s3.ListBucketInventoryConfigurationsInput{ + Bucket: bucket.Name, + ContinuationToken: nil, + } + + for { + configurations, err := s3client.ListBucketInventoryConfigurations(request) + if err != nil { + return nil, errors.Wrapf( + err, + "Error listing bucket inventory configuration %s", + *bucket.Name, + ) + } + inventoryConfigurations = append(inventoryConfigurations, configurations.InventoryConfigurationList...) + if configurations.IsTruncated != nil && *configurations.IsTruncated { + request.ContinuationToken = configurations.NextContinuationToken + } else { + break + } + } + + s.cache.Put(cacheKey, inventoryConfigurations) + return inventoryConfigurations, nil +} + +func (s *s3Repository) ListBucketMetricsConfigurations(bucket *s3.Bucket, region string) ([]*s3.MetricsConfiguration, error) { + cacheKey := fmt.Sprintf("s3ListBucketMetricsConfigurations_%s_%s", *bucket.Name, region) + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*s3.MetricsConfiguration), nil + } + + metricsConfigurationList := make([]*s3.MetricsConfiguration, 0) + s3client := s.clientFactory.GetS3Client(&awssdk.Config{Region: ®ion}) + request := &s3.ListBucketMetricsConfigurationsInput{ + Bucket: bucket.Name, + ContinuationToken: nil, + } + + for { + configurations, err := s3client.ListBucketMetricsConfigurations(request) + if err != nil { + return nil, errors.Wrapf( + err, + "Error listing bucket metrics configuration %s", + *bucket.Name, + ) + } + metricsConfigurationList = append(metricsConfigurationList, configurations.MetricsConfigurationList...) + if configurations.IsTruncated != nil && *configurations.IsTruncated { + request.ContinuationToken = configurations.NextContinuationToken + } else { + break + } + } + + s.cache.Put(cacheKey, metricsConfigurationList) + return metricsConfigurationList, nil +} + +func (s *s3Repository) ListBucketAnalyticsConfigurations(bucket *s3.Bucket, region string) ([]*s3.AnalyticsConfiguration, error) { + cacheKey := fmt.Sprintf("s3ListBucketAnalyticsConfigurations_%s_%s", *bucket.Name, region) + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*s3.AnalyticsConfiguration), nil + } + + analyticsConfigurationList := make([]*s3.AnalyticsConfiguration, 0) + s3client := s.clientFactory.GetS3Client(&awssdk.Config{Region: ®ion}) + request := &s3.ListBucketAnalyticsConfigurationsInput{ + Bucket: bucket.Name, + ContinuationToken: nil, + } + + for { + configurations, err := s3client.ListBucketAnalyticsConfigurations(request) + if err != nil { + return nil, errors.Wrapf( + err, + "Error listing bucket analytics configuration %s", + *bucket.Name, + ) + } + analyticsConfigurationList = append(analyticsConfigurationList, configurations.AnalyticsConfigurationList...) + + if configurations.IsTruncated != nil && *configurations.IsTruncated { + request.ContinuationToken = configurations.NextContinuationToken + } else { + break + } + } + + s.cache.Put(cacheKey, analyticsConfigurationList) + return analyticsConfigurationList, nil +} + +func (s *s3Repository) GetBucketLocation(bucketName string) (string, error) { + cacheKey := fmt.Sprintf("s3GetBucketLocation_%s", bucketName) + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.(string), nil + } + + bucketLocationRequest := s3.GetBucketLocationInput{Bucket: &bucketName} + bucketLocationResponse, err := s.clientFactory.GetS3Client(nil).GetBucketLocation(&bucketLocationRequest) + if err != nil { + awsErr, ok := err.(awserr.Error) + if ok && awsErr.Code() == s3.ErrCodeNoSuchBucket { + logrus.WithFields(logrus.Fields{ + "bucket": bucketName, + }).Warning("Unable to retrieve bucket region, this may be an inconsistency in S3 api for fresh deleted bucket, skipping ...") + return "", nil + } + return "", err + } + + var location string + + // Buckets in Region us-east-1 have a LocationConstraint of null. + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketLocation.html#API_GetBucketLocation_ResponseSyntax + if bucketLocationResponse.LocationConstraint == nil { + location = "us-east-1" + } else { + location = *bucketLocationResponse.LocationConstraint + } + + if location == "EU" { + location = "eu-west-1" + } + + s.cache.Put(cacheKey, location) + return location, nil +} diff --git a/enumeration/remote/aws/repository/s3_repository_test.go b/enumeration/remote/aws/repository/s3_repository_test.go new file mode 100644 index 00000000..a73c80c8 --- /dev/null +++ b/enumeration/remote/aws/repository/s3_repository_test.go @@ -0,0 +1,866 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/aws/client" + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/pkg/errors" + "github.com/r3labs/diff/v2" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/assert" +) + +func Test_s3Repository_ListAllBuckets(t *testing.T) { + + tests := []struct { + name string + mocks func(client *awstest.MockFakeS3) + want []*s3.Bucket + wantErr error + }{ + { + name: "List buckets", + mocks: func(client *awstest.MockFakeS3) { + client.On("ListBuckets", &s3.ListBucketsInput{}).Return( + &s3.ListBucketsOutput{ + Buckets: []*s3.Bucket{ + {Name: aws.String("bucket1")}, + {Name: aws.String("bucket2")}, + {Name: aws.String("bucket3")}, + }, + }, + nil, + ).Once() + }, + want: []*s3.Bucket{ + {Name: aws.String("bucket1")}, + {Name: aws.String("bucket2")}, + {Name: aws.String("bucket3")}, + }, + }, + { + name: "Error listing buckets", + mocks: func(client *awstest.MockFakeS3) { + client.On("ListBuckets", &s3.ListBucketsInput{}).Return( + nil, + awserr.NewRequestFailure(nil, 403, ""), + ).Once() + }, + want: nil, + wantErr: awserr.NewRequestFailure(nil, 403, ""), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", (*aws.Config)(nil)).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.ListAllBuckets() + factory.AssertExpectations(t) + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllBuckets() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*s3.Bucket{}, store.Get("s3ListAllBuckets")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_s3Repository_GetBucketNotification(t *testing.T) { + + tests := []struct { + name string + bucketName, region string + mocks func(client *awstest.MockFakeS3) + want *s3.NotificationConfiguration + wantErr string + }{ + { + name: "get empty bucket notification", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.NotificationConfiguration{}, + nil, + ).Once() + }, + want: nil, + }, + { + name: "get bucket notification with lambda config", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.NotificationConfiguration{ + LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{ + { + Id: aws.String("test"), + }, + }, + }, + nil, + ).Once() + }, + want: &s3.NotificationConfiguration{ + LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{ + { + Id: aws.String("test"), + }, + }, + }, + }, + { + name: "get bucket notification with queue config", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.NotificationConfiguration{ + QueueConfigurations: []*s3.QueueConfiguration{ + { + Id: awssdk.String("test"), + }, + }, + }, + nil, + ).Once() + }, + want: &s3.NotificationConfiguration{ + QueueConfigurations: []*s3.QueueConfiguration{ + { + Id: awssdk.String("test"), + }, + }, + }, + }, + { + name: "get bucket notification with topic config", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.NotificationConfiguration{ + TopicConfigurations: []*s3.TopicConfiguration{ + { + Id: awssdk.String("test"), + }, + }, + }, + nil, + ).Once() + }, + want: &s3.NotificationConfiguration{ + TopicConfigurations: []*s3.TopicConfiguration{ + { + Id: awssdk.String("test"), + }, + }, + }, + }, + { + name: "get bucket location when error", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ + Bucket: aws.String("test-bucket"), + }).Return( + nil, + awserr.New("UnknownError", "aws error", nil), + ).Once() + }, + wantErr: "Error listing bucket notification configuration test-bucket: UnknownError: aws error", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", &aws.Config{Region: &tt.region}).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.GetBucketNotification(tt.bucketName, tt.region) + factory.AssertExpectations(t) + if err != nil && tt.wantErr == "" { + t.Fatalf("Unexpected error %+v", err) + } + if err != nil { + assert.Equal(t, tt.wantErr, err.Error()) + } + + if err == nil && tt.want != nil { + // Check that results were cached + cachedData, err := r.GetBucketNotification(tt.bucketName, tt.region) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, &s3.NotificationConfiguration{}, store.Get(fmt.Sprintf("s3GetBucketNotification_%s_%s", tt.bucketName, tt.region))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_s3Repository_GetBucketPolicy(t *testing.T) { + + tests := []struct { + name string + bucketName, region string + mocks func(client *awstest.MockFakeS3) + want *string + wantErr string + }{ + { + name: "get nil bucket policy", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.GetBucketPolicyOutput{}, + nil, + ).Once() + }, + want: nil, + }, + { + name: "get empty bucket policy", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.GetBucketPolicyOutput{ + Policy: awssdk.String(""), + }, + nil, + ).Once() + }, + want: nil, + }, + { + name: "get bucket policy", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ + Bucket: aws.String("test-bucket"), + }).Return( + &s3.GetBucketPolicyOutput{ + Policy: awssdk.String("foobar"), + }, + nil, + ).Once() + }, + want: awssdk.String("foobar"), + }, + { + name: "get bucket location on 404", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ + Bucket: aws.String("test-bucket"), + }).Return( + nil, + awserr.New("NoSuchBucketPolicy", "", nil), + ).Once() + }, + want: nil, + }, + { + name: "get bucket location when error", + bucketName: "test-bucket", + region: "us-east-1", + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ + Bucket: aws.String("test-bucket"), + }).Return( + nil, + awserr.New("UnknownError", "aws error", nil), + ).Once() + }, + wantErr: "Error listing bucket policy test-bucket: UnknownError: aws error", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", &aws.Config{Region: &tt.region}).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.GetBucketPolicy(tt.bucketName, tt.region) + factory.AssertExpectations(t) + if err != nil && tt.wantErr == "" { + t.Fatalf("Unexpected error %+v", err) + } + if err != nil { + assert.Equal(t, tt.wantErr, err.Error()) + } + + if err == nil && tt.want != nil { + // Check that results were cached + cachedData, err := r.GetBucketPolicy(tt.bucketName, tt.region) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, awssdk.String(""), store.Get(fmt.Sprintf("s3GetBucketPolicy_%s_%s", tt.bucketName, tt.region))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_s3Repository_ListBucketInventoryConfigurations(t *testing.T) { + tests := []struct { + name string + input struct { + bucket s3.Bucket + region string + } + mocks func(client *awstest.MockFakeS3) + want []*s3.InventoryConfiguration + wantErr string + }{ + { + name: "List inventory configs", + input: struct { + bucket s3.Bucket + region string + }{ + bucket: s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + region: "us-east-1", + }, + mocks: func(client *awstest.MockFakeS3) { + client.On( + "ListBucketInventoryConfigurations", + &s3.ListBucketInventoryConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + ContinuationToken: nil, + }, + ).Return( + &s3.ListBucketInventoryConfigurationsOutput{ + InventoryConfigurationList: []*s3.InventoryConfiguration{ + {Id: awssdk.String("config1")}, + {Id: awssdk.String("config2")}, + {Id: awssdk.String("config3")}, + }, + IsTruncated: awssdk.Bool(true), + NextContinuationToken: awssdk.String("nexttoken"), + }, + nil, + ).Once() + client.On( + "ListBucketInventoryConfigurations", + &s3.ListBucketInventoryConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + ContinuationToken: awssdk.String("nexttoken"), + }, + ).Return( + &s3.ListBucketInventoryConfigurationsOutput{ + InventoryConfigurationList: []*s3.InventoryConfiguration{ + {Id: awssdk.String("config4")}, + {Id: awssdk.String("config5")}, + {Id: awssdk.String("config6")}, + }, + IsTruncated: awssdk.Bool(false), + }, + nil, + ).Once() + }, + want: []*s3.InventoryConfiguration{ + {Id: awssdk.String("config1")}, + {Id: awssdk.String("config2")}, + {Id: awssdk.String("config3")}, + {Id: awssdk.String("config4")}, + {Id: awssdk.String("config5")}, + {Id: awssdk.String("config6")}, + }, + }, + { + name: "Error listing inventory configs", + input: struct { + bucket s3.Bucket + region string + }{ + bucket: s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + region: "us-east-1", + }, + mocks: func(client *awstest.MockFakeS3) { + client.On( + "ListBucketInventoryConfigurations", + &s3.ListBucketInventoryConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + }, + ).Return( + nil, + errors.New("aws error"), + ).Once() + }, + want: nil, + wantErr: "Error listing bucket inventory configuration test-bucket: aws error", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", &aws.Config{Region: awssdk.String(tt.input.region)}).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.ListBucketInventoryConfigurations(&tt.input.bucket, tt.input.region) + factory.AssertExpectations(t) + if err != nil && tt.wantErr == "" { + t.Fatalf("Unexpected error %+v", err) + } + if err != nil { + assert.Equal(t, tt.wantErr, err.Error()) + } + + if err == nil { + // Check that results were cached + cachedData, err := r.ListBucketInventoryConfigurations(&tt.input.bucket, tt.input.region) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*s3.InventoryConfiguration{}, store.Get(fmt.Sprintf("s3ListBucketInventoryConfigurations_%s_%s", *tt.input.bucket.Name, tt.input.region))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_s3Repository_ListBucketMetricsConfigurations(t *testing.T) { + tests := []struct { + name string + input struct { + bucket s3.Bucket + region string + } + mocks func(client *awstest.MockFakeS3) + want []*s3.MetricsConfiguration + wantErr string + }{ + { + name: "List metrics configs", + input: struct { + bucket s3.Bucket + region string + }{ + bucket: s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + region: "us-east-1", + }, + mocks: func(client *awstest.MockFakeS3) { + client.On( + "ListBucketMetricsConfigurations", + &s3.ListBucketMetricsConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + ContinuationToken: nil, + }, + ).Return( + &s3.ListBucketMetricsConfigurationsOutput{ + MetricsConfigurationList: []*s3.MetricsConfiguration{ + {Id: awssdk.String("metric1")}, + {Id: awssdk.String("metric2")}, + {Id: awssdk.String("metric3")}, + }, + IsTruncated: awssdk.Bool(true), + NextContinuationToken: awssdk.String("nexttoken"), + }, + nil, + ).Once() + client.On( + "ListBucketMetricsConfigurations", + &s3.ListBucketMetricsConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + ContinuationToken: awssdk.String("nexttoken"), + }, + ).Return( + &s3.ListBucketMetricsConfigurationsOutput{ + MetricsConfigurationList: []*s3.MetricsConfiguration{ + {Id: awssdk.String("metric4")}, + {Id: awssdk.String("metric5")}, + {Id: awssdk.String("metric6")}, + }, + IsTruncated: awssdk.Bool(false), + }, + nil, + ).Once() + }, + want: []*s3.MetricsConfiguration{ + {Id: awssdk.String("metric1")}, + {Id: awssdk.String("metric2")}, + {Id: awssdk.String("metric3")}, + {Id: awssdk.String("metric4")}, + {Id: awssdk.String("metric5")}, + {Id: awssdk.String("metric6")}, + }, + }, + { + name: "Error listing metrics configs", + input: struct { + bucket s3.Bucket + region string + }{ + bucket: s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + region: "us-east-1", + }, + mocks: func(client *awstest.MockFakeS3) { + client.On( + "ListBucketMetricsConfigurations", + &s3.ListBucketMetricsConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + }, + ).Return( + nil, + errors.New("aws error"), + ).Once() + }, + want: nil, + wantErr: "Error listing bucket metrics configuration test-bucket: aws error", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", &aws.Config{Region: awssdk.String(tt.input.region)}).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.ListBucketMetricsConfigurations(&tt.input.bucket, tt.input.region) + factory.AssertExpectations(t) + if err != nil && tt.wantErr == "" { + t.Fatalf("Unexpected error %+v", err) + } + if err != nil { + assert.Equal(t, tt.wantErr, err.Error()) + } + + if err == nil { + // Check that results were cached + cachedData, err := r.ListBucketMetricsConfigurations(&tt.input.bucket, tt.input.region) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*s3.MetricsConfiguration{}, store.Get(fmt.Sprintf("s3ListBucketMetricsConfigurations_%s_%s", *tt.input.bucket.Name, tt.input.region))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_s3Repository_ListBucketAnalyticsConfigurations(t *testing.T) { + tests := []struct { + name string + input struct { + bucket s3.Bucket + region string + } + mocks func(client *awstest.MockFakeS3) + want []*s3.AnalyticsConfiguration + wantErr string + }{ + { + name: "List analytics configs", + input: struct { + bucket s3.Bucket + region string + }{ + bucket: s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + region: "us-east-1", + }, + mocks: func(client *awstest.MockFakeS3) { + client.On( + "ListBucketAnalyticsConfigurations", + &s3.ListBucketAnalyticsConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + ContinuationToken: nil, + }, + ).Return( + &s3.ListBucketAnalyticsConfigurationsOutput{ + AnalyticsConfigurationList: []*s3.AnalyticsConfiguration{ + {Id: awssdk.String("analytic1")}, + {Id: awssdk.String("analytic2")}, + {Id: awssdk.String("analytic3")}, + }, + IsTruncated: awssdk.Bool(true), + NextContinuationToken: awssdk.String("nexttoken"), + }, + nil, + ).Once() + client.On( + "ListBucketAnalyticsConfigurations", + &s3.ListBucketAnalyticsConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + ContinuationToken: awssdk.String("nexttoken"), + }, + ).Return( + &s3.ListBucketAnalyticsConfigurationsOutput{ + AnalyticsConfigurationList: []*s3.AnalyticsConfiguration{ + {Id: awssdk.String("analytic4")}, + {Id: awssdk.String("analytic5")}, + {Id: awssdk.String("analytic6")}, + }, + IsTruncated: awssdk.Bool(false), + }, + nil, + ).Once() + }, + want: []*s3.AnalyticsConfiguration{ + {Id: awssdk.String("analytic1")}, + {Id: awssdk.String("analytic2")}, + {Id: awssdk.String("analytic3")}, + {Id: awssdk.String("analytic4")}, + {Id: awssdk.String("analytic5")}, + {Id: awssdk.String("analytic6")}, + }, + }, + { + name: "Error listing analytics configs", + input: struct { + bucket s3.Bucket + region string + }{ + bucket: s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + region: "us-east-1", + }, + mocks: func(client *awstest.MockFakeS3) { + client.On( + "ListBucketAnalyticsConfigurations", + &s3.ListBucketAnalyticsConfigurationsInput{ + Bucket: awssdk.String("test-bucket"), + }, + ).Return( + nil, + errors.New("aws error"), + ).Once() + }, + want: nil, + wantErr: "Error listing bucket analytics configuration test-bucket: aws error", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", &aws.Config{Region: awssdk.String(tt.input.region)}).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.ListBucketAnalyticsConfigurations(&tt.input.bucket, tt.input.region) + factory.AssertExpectations(t) + if err != nil && tt.wantErr == "" { + t.Fatalf("Unexpected error %+v", err) + } + if err != nil { + assert.Equal(t, tt.wantErr, err.Error()) + } + + if err == nil { + // Check that results were cached + cachedData, err := r.ListBucketAnalyticsConfigurations(&tt.input.bucket, tt.input.region) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*s3.AnalyticsConfiguration{}, store.Get(fmt.Sprintf("s3ListBucketAnalyticsConfigurations_%s_%s", *tt.input.bucket.Name, tt.input.region))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_s3Repository_GetBucketLocation(t *testing.T) { + + tests := []struct { + name string + bucket *s3.Bucket + mocks func(client *awstest.MockFakeS3) + want string + wantErr string + }{ + { + name: "get bucket location", + bucket: &s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketLocation", &s3.GetBucketLocationInput{ + Bucket: awssdk.String("test-bucket"), + }).Return( + &s3.GetBucketLocationOutput{ + LocationConstraint: awssdk.String("eu-east-1"), + }, + nil, + ).Once() + }, + want: "eu-east-1", + }, + { + name: "get bucket location for us-east-2", + bucket: &s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketLocation", &s3.GetBucketLocationInput{ + Bucket: awssdk.String("test-bucket"), + }).Return( + &s3.GetBucketLocationOutput{}, + nil, + ).Once() + }, + want: "us-east-1", + }, + { + name: "get bucket location when no such bucket", + bucket: &s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketLocation", &s3.GetBucketLocationInput{ + Bucket: awssdk.String("test-bucket"), + }).Return( + nil, + awserr.New(s3.ErrCodeNoSuchBucket, "", nil), + ).Once() + }, + want: "", + }, + { + name: "get bucket location when error", + bucket: &s3.Bucket{ + Name: awssdk.String("test-bucket"), + }, + mocks: func(client *awstest.MockFakeS3) { + client.On("GetBucketLocation", &s3.GetBucketLocationInput{ + Bucket: awssdk.String("test-bucket"), + }).Return( + nil, + awserr.New("UnknownError", "aws error", nil), + ).Once() + }, + wantErr: "UnknownError: aws error", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + mockedClient := &awstest.MockFakeS3{} + tt.mocks(mockedClient) + factory := client.MockAwsClientFactoryInterface{} + factory.On("GetS3Client", (*aws.Config)(nil)).Return(mockedClient).Once() + r := NewS3Repository(&factory, store) + got, err := r.GetBucketLocation(*tt.bucket.Name) + factory.AssertExpectations(t) + if err != nil && tt.wantErr == "" { + t.Fatalf("Unexpected error %+v", err) + } + if err != nil { + assert.Equal(t, tt.wantErr, err.Error()) + } + + if err == nil && tt.want != "" { + // Check that results were cached + cachedData, err := r.GetBucketLocation(*tt.bucket.Name) + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, "", store.Get(fmt.Sprintf("s3GetBucketLocation_%s", *tt.bucket.Name))) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/sns_repository.go b/enumeration/remote/aws/repository/sns_repository.go new file mode 100644 index 00000000..b4c17f5b --- /dev/null +++ b/enumeration/remote/aws/repository/sns_repository.go @@ -0,0 +1,67 @@ +package repository + +import ( + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go/service/sns/snsiface" + "github.com/snyk/driftctl/enumeration/remote/cache" +) + +type SNSRepository interface { + ListAllTopics() ([]*sns.Topic, error) + ListAllSubscriptions() ([]*sns.Subscription, error) +} + +type snsRepository struct { + client snsiface.SNSAPI + cache cache.Cache +} + +func NewSNSRepository(session *session.Session, c cache.Cache) *snsRepository { + return &snsRepository{ + sns.New(session), + c, + } +} + +func (r *snsRepository) ListAllTopics() ([]*sns.Topic, error) { + + cacheKey := "snsListAllTopics" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*sns.Topic), nil + } + + var topics []*sns.Topic + input := &sns.ListTopicsInput{} + err := r.client.ListTopicsPages(input, func(res *sns.ListTopicsOutput, lastPage bool) bool { + topics = append(topics, res.Topics...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, topics) + return topics, nil +} + +func (r *snsRepository) ListAllSubscriptions() ([]*sns.Subscription, error) { + if v := r.cache.Get("snsListAllSubscriptions"); v != nil { + return v.([]*sns.Subscription), nil + } + + var subscriptions []*sns.Subscription + input := &sns.ListSubscriptionsInput{} + err := r.client.ListSubscriptionsPages(input, func(res *sns.ListSubscriptionsOutput, lastPage bool) bool { + subscriptions = append(subscriptions, res.Subscriptions...) + return !lastPage + }) + if err != nil { + return nil, err + } + + r.cache.Put("snsListAllSubscriptions", subscriptions) + return subscriptions, nil +} diff --git a/enumeration/remote/aws/repository/sns_repository_test.go b/enumeration/remote/aws/repository/sns_repository_test.go new file mode 100644 index 00000000..d5b6778d --- /dev/null +++ b/enumeration/remote/aws/repository/sns_repository_test.go @@ -0,0 +1,161 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws" + awstest "github.com/snyk/driftctl/test/aws" + "github.com/stretchr/testify/mock" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/service/sns" +) + +func Test_snsRepository_ListAllTopics(t *testing.T) { + + tests := []struct { + name string + mocks func(client *awstest.MockFakeSNS) + want []*sns.Topic + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeSNS) { + client.On("ListTopicsPages", + &sns.ListTopicsInput{}, + mock.MatchedBy(func(callback func(res *sns.ListTopicsOutput, lastPage bool) bool) bool { + callback(&sns.ListTopicsOutput{ + Topics: []*sns.Topic{ + {TopicArn: aws.String("arn1")}, + {TopicArn: aws.String("arn2")}, + {TopicArn: aws.String("arn3")}, + }, + }, false) + callback(&sns.ListTopicsOutput{ + Topics: []*sns.Topic{ + {TopicArn: aws.String("arn4")}, + {TopicArn: aws.String("arn5")}, + {TopicArn: aws.String("arn6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*sns.Topic{ + {TopicArn: aws.String("arn1")}, + {TopicArn: aws.String("arn2")}, + {TopicArn: aws.String("arn3")}, + {TopicArn: aws.String("arn4")}, + {TopicArn: aws.String("arn5")}, + {TopicArn: aws.String("arn6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeSNS{} + tt.mocks(client) + r := &snsRepository{ + client: client, + cache: store, + } + got, err := r.ListAllTopics() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllTopics() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*sns.Topic{}, store.Get("snsListAllTopics")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_snsRepository_ListAllSubscriptions(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeSNS) + want []*sns.Subscription + wantErr error + }{ + { + name: "List with 2 pages", + mocks: func(client *awstest.MockFakeSNS) { + client.On("ListSubscriptionsPages", + &sns.ListSubscriptionsInput{}, + mock.MatchedBy(func(callback func(res *sns.ListSubscriptionsOutput, lastPage bool) bool) bool { + callback(&sns.ListSubscriptionsOutput{ + Subscriptions: []*sns.Subscription{ + {TopicArn: aws.String("arn1"), SubscriptionArn: aws.String("SubArn1")}, + {TopicArn: aws.String("arn2"), SubscriptionArn: aws.String("SubArn2")}, + {TopicArn: aws.String("arn3"), SubscriptionArn: aws.String("SubArn3")}, + }, + }, false) + callback(&sns.ListSubscriptionsOutput{ + Subscriptions: []*sns.Subscription{ + {TopicArn: aws.String("arn4"), SubscriptionArn: aws.String("SubArn4")}, + {TopicArn: aws.String("arn5"), SubscriptionArn: aws.String("SubArn5")}, + {TopicArn: aws.String("arn6"), SubscriptionArn: aws.String("SubArn6")}, + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*sns.Subscription{ + {TopicArn: aws.String("arn1"), SubscriptionArn: aws.String("SubArn1")}, + {TopicArn: aws.String("arn2"), SubscriptionArn: aws.String("SubArn2")}, + {TopicArn: aws.String("arn3"), SubscriptionArn: aws.String("SubArn3")}, + {TopicArn: aws.String("arn4"), SubscriptionArn: aws.String("SubArn4")}, + {TopicArn: aws.String("arn5"), SubscriptionArn: aws.String("SubArn5")}, + {TopicArn: aws.String("arn6"), SubscriptionArn: aws.String("SubArn6")}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeSNS{} + tt.mocks(client) + r := &snsRepository{ + client: client, + cache: store, + } + got, err := r.ListAllSubscriptions() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllSubscriptions() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*sns.Subscription{}, store.Get("snsListAllSubscriptions")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/repository/sqs_repository.go b/enumeration/remote/aws/repository/sqs_repository.go new file mode 100644 index 00000000..dbf46bc2 --- /dev/null +++ b/enumeration/remote/aws/repository/sqs_repository.go @@ -0,0 +1,72 @@ +package repository + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" +) + +type SQSRepository interface { + ListAllQueues() ([]*string, error) + GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) +} + +type sqsRepository struct { + client sqsiface.SQSAPI + cache cache.Cache +} + +func NewSQSRepository(session *session.Session, c cache.Cache) *sqsRepository { + return &sqsRepository{ + sqs.New(session), + c, + } +} + +func (r *sqsRepository) GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) { + cacheKey := fmt.Sprintf("sqsGetQueueAttributes_%s", url) + if v := r.cache.Get(cacheKey); v != nil { + return v.(*sqs.GetQueueAttributesOutput), nil + } + + attributes, err := r.client.GetQueueAttributes(&sqs.GetQueueAttributesInput{ + AttributeNames: aws.StringSlice([]string{sqs.QueueAttributeNamePolicy}), + QueueUrl: &url, + }) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, attributes) + + return attributes, nil +} + +func (r *sqsRepository) ListAllQueues() ([]*string, error) { + + cacheKey := "sqsListAllQueues" + v := r.cache.GetAndLock(cacheKey) + defer r.cache.Unlock(cacheKey) + if v != nil { + return v.([]*string), nil + } + + var queues []*string + input := sqs.ListQueuesInput{} + err := r.client.ListQueuesPages(&input, + func(resp *sqs.ListQueuesOutput, lastPage bool) bool { + queues = append(queues, resp.QueueUrls...) + return !lastPage + }, + ) + if err != nil { + return nil, err + } + + r.cache.Put(cacheKey, queues) + return queues, nil +} diff --git a/enumeration/remote/aws/repository/sqs_repository_test.go b/enumeration/remote/aws/repository/sqs_repository_test.go new file mode 100644 index 00000000..51889afb --- /dev/null +++ b/enumeration/remote/aws/repository/sqs_repository_test.go @@ -0,0 +1,145 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "strings" + "testing" + + awssdk "github.com/aws/aws-sdk-go/aws" + awstest "github.com/snyk/driftctl/test/aws" + + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_sqsRepository_ListAllQueues(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeSQS) + want []*string + wantErr error + }{ + { + name: "list with multiple pages", + mocks: func(client *awstest.MockFakeSQS) { + client.On("ListQueuesPages", + &sqs.ListQueuesInput{}, + mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool { + callback(&sqs.ListQueuesOutput{ + QueueUrls: []*string{ + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), + }, + }, false) + callback(&sqs.ListQueuesOutput{ + QueueUrls: []*string{ + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), + }, + }, true) + return true + })).Return(nil).Once() + }, + want: []*string{ + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeSQS{} + tt.mocks(client) + r := &sqsRepository{ + client: client, + cache: store, + } + got, err := r.ListAllQueues() + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.ListAllQueues() + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, []*string{}, store.Get("sqsListAllQueues")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} + +func Test_sqsRepository_GetQueueAttributes(t *testing.T) { + tests := []struct { + name string + mocks func(client *awstest.MockFakeSQS) + want *sqs.GetQueueAttributesOutput + wantErr error + }{ + { + name: "get attributes", + mocks: func(client *awstest.MockFakeSQS) { + client.On( + "GetQueueAttributes", + &sqs.GetQueueAttributesInput{ + AttributeNames: awssdk.StringSlice([]string{sqs.QueueAttributeNamePolicy}), + QueueUrl: awssdk.String("http://example.com"), + }, + ).Return( + &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + sqs.QueueAttributeNamePolicy: awssdk.String("foobar"), + }, + }, + nil, + ).Once() + }, + want: &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + sqs.QueueAttributeNamePolicy: awssdk.String("foobar"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := cache.New(1) + client := &awstest.MockFakeSQS{} + tt.mocks(client) + r := &sqsRepository{ + client: client, + cache: store, + } + got, err := r.GetQueueAttributes("http://example.com") + assert.Equal(t, tt.wantErr, err) + + if err == nil { + // Check that results were cached + cachedData, err := r.GetQueueAttributes("http://example.com") + assert.NoError(t, err) + assert.Equal(t, got, cachedData) + assert.IsType(t, &sqs.GetQueueAttributesOutput{}, store.Get("sqsGetQueueAttributes_http://example.com")) + } + + changelog, err := diff.Diff(got, tt.want) + assert.Nil(t, err) + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) + } + t.Fail() + } + }) + } +} diff --git a/enumeration/remote/aws/route53_health_check_enumerator.go b/enumeration/remote/aws/route53_health_check_enumerator.go new file mode 100644 index 00000000..15d7019a --- /dev/null +++ b/enumeration/remote/aws/route53_health_check_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type Route53HealthCheckEnumerator struct { + repository repository.Route53Repository + factory resource.ResourceFactory +} + +func NewRoute53HealthCheckEnumerator(repo repository.Route53Repository, factory resource.ResourceFactory) *Route53HealthCheckEnumerator { + return &Route53HealthCheckEnumerator{ + repo, + factory, + } +} + +func (e *Route53HealthCheckEnumerator) SupportedType() resource.ResourceType { + return aws.AwsRoute53HealthCheckResourceType +} + +func (e *Route53HealthCheckEnumerator) Enumerate() ([]*resource.Resource, error) { + healthChecks, err := e.repository.ListAllHealthChecks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(healthChecks)) + + for _, healthCheck := range healthChecks { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *healthCheck.Id, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/route53_record_enumerator.go b/enumeration/remote/aws/route53_record_enumerator.go new file mode 100644 index 00000000..4cefb5c5 --- /dev/null +++ b/enumeration/remote/aws/route53_record_enumerator.go @@ -0,0 +1,99 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strconv" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type Route53RecordEnumerator struct { + client repository.Route53Repository + factory resource.ResourceFactory +} + +func NewRoute53RecordEnumerator(repo repository.Route53Repository, factory resource.ResourceFactory) *Route53RecordEnumerator { + return &Route53RecordEnumerator{ + repo, + factory, + } +} + +func (e *Route53RecordEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsRoute53RecordResourceType +} + +func (e *Route53RecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.client.ListAllZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsRoute53ZoneResourceType) + } + + results := make([]*resource.Resource, 0, len(zones)) + + for _, hostedZone := range zones { + records, err := e.listRecordsForZone(strings.TrimPrefix(*hostedZone.Id, "/hostedzone/")) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results = append(results, records...) + } + + return results, err +} + +func (e *Route53RecordEnumerator) listRecordsForZone(zoneId string) ([]*resource.Resource, error) { + + records, err := e.client.ListRecordsForZone(zoneId) + if err != nil { + return nil, err + } + + results := make([]*resource.Resource, 0, len(records)) + + for _, raw := range records { + rawType := *raw.Type + rawName := *raw.Name + rawSetIdentifier := raw.SetIdentifier + + vars := []string{ + zoneId, + strings.ToLower(strings.TrimSuffix(rawName, ".")), + rawType, + } + if rawSetIdentifier != nil { + vars = append(vars, *rawSetIdentifier) + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + e.cleanRecordName(strings.Join(vars, "_")), + map[string]interface{}{ + "type": rawType, + }, + ), + ) + } + + return results, nil +} + +// cleanRecordName +// Route 53 stores certain characters with the octal equivalent in ASCII format. +// This function converts all of these characters back into the original character. +// E.g. "*" is stored as "\\052" and "@" as "\\100" +func (e *Route53RecordEnumerator) cleanRecordName(name string) string { + str := name + s, err := strconv.Unquote(`"` + str + `"`) + if err != nil { + return str + } + return s +} diff --git a/enumeration/remote/aws/route53_zone_enumerator.go b/enumeration/remote/aws/route53_zone_enumerator.go new file mode 100644 index 00000000..7e51bf16 --- /dev/null +++ b/enumeration/remote/aws/route53_zone_enumerator.go @@ -0,0 +1,48 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type Route53ZoneSupplier struct { + client repository.Route53Repository + factory resource.ResourceFactory +} + +func NewRoute53ZoneEnumerator(repo repository.Route53Repository, factory resource.ResourceFactory) *Route53ZoneSupplier { + return &Route53ZoneSupplier{ + repo, + factory, + } +} + +func (e *Route53ZoneSupplier) SupportedType() resource.ResourceType { + return resourceaws.AwsRoute53ZoneResourceType +} + +func (e *Route53ZoneSupplier) Enumerate() ([]*resource.Resource, error) { + zones, err := e.client.ListAllZones() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(zones)) + + for _, hostedZone := range zones { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + strings.TrimPrefix(*hostedZone.Id, "/hostedzone/"), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/s3_bucket_analytic_enumerator.go b/enumeration/remote/aws/s3_bucket_analytic_enumerator.go new file mode 100644 index 00000000..b2c931ee --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_analytic_enumerator.go @@ -0,0 +1,80 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketAnalyticEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketAnalyticEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketAnalyticEnumerator { + return &S3BucketAnalyticEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketAnalyticEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketAnalyticsConfigurationResourceType +} + +func (e *S3BucketAnalyticEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket analytic") + continue + } + + analyticsConfigurationList, err := e.repository.ListBucketAnalyticsConfigurations(bucket, region) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, analytics := range analyticsConfigurationList { + id := fmt.Sprintf("%s:%s", *bucket.Name, *analytics.Id) + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{ + "region": region, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/s3_bucket_enumerator.go b/enumeration/remote/aws/s3_bucket_enumerator.go new file mode 100644 index 00000000..e96b69d5 --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_enumerator.go @@ -0,0 +1,69 @@ +package aws + +import ( + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketEnumerator { + return &S3BucketEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketResourceType +} + +func (e *S3BucketEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket") + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *bucket.Name, + map[string]interface{}{ + "region": region, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/s3_bucket_inventory_enumerator.go b/enumeration/remote/aws/s3_bucket_inventory_enumerator.go new file mode 100644 index 00000000..97518550 --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_inventory_enumerator.go @@ -0,0 +1,81 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketInventoryEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketInventoryEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketInventoryEnumerator { + return &S3BucketInventoryEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketInventoryEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketInventoryResourceType +} + +func (e *S3BucketInventoryEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket inventory") + continue + } + + inventoryConfigurations, err := e.repository.ListBucketInventoryConfigurations(bucket, region) + if err != nil { + // TODO: we should think about a way to ignore just one bucket inventory listing + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, config := range inventoryConfigurations { + id := fmt.Sprintf("%s:%s", *bucket.Name, *config.Id) + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{ + "region": region, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/s3_bucket_metrics_enumerator.go b/enumeration/remote/aws/s3_bucket_metrics_enumerator.go new file mode 100644 index 00000000..b70675ff --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_metrics_enumerator.go @@ -0,0 +1,80 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketMetricsEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketMetricsEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketMetricsEnumerator { + return &S3BucketMetricsEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketMetricsEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketMetricResourceType +} + +func (e *S3BucketMetricsEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket") + continue + } + + metricsConfigurationList, err := e.repository.ListBucketMetricsConfigurations(bucket, region) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, metric := range metricsConfigurationList { + id := fmt.Sprintf("%s:%s", *bucket.Name, *metric.Id) + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{ + "region": region, + }, + ), + ) + } + } + + return results, nil +} diff --git a/enumeration/remote/aws/s3_bucket_notification_enumerator.go b/enumeration/remote/aws/s3_bucket_notification_enumerator.go new file mode 100644 index 00000000..8f9ec599 --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_notification_enumerator.go @@ -0,0 +1,84 @@ +package aws + +import ( + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketNotificationEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketNotificationEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketNotificationEnumerator { + return &S3BucketNotificationEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketNotificationEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketNotificationResourceType +} + +func (e *S3BucketNotificationEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket") + continue + } + + notification, err := e.repository.GetBucketNotification(*bucket.Name, region) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + + if notification == nil { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped empty bucket notification") + continue + } + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *bucket.Name, + map[string]interface{}{ + "region": region, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/s3_bucket_policy_enumerator.go b/enumeration/remote/aws/s3_bucket_policy_enumerator.go new file mode 100644 index 00000000..c2f860f5 --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_policy_enumerator.go @@ -0,0 +1,78 @@ +package aws + +import ( + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketPolicyEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketPolicyEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketPolicyEnumerator { + return &S3BucketPolicyEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketPolicyResourceType +} + +func (e *S3BucketPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket policy") + continue + } + + policy, err := e.repository.GetBucketPolicy(*bucket.Name, region) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + + if policy != nil { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *bucket.Name, + map[string]interface{}{ + "region": region, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/s3_bucket_public_access_block_enumerator.go b/enumeration/remote/aws/s3_bucket_public_access_block_enumerator.go new file mode 100644 index 00000000..be658c3e --- /dev/null +++ b/enumeration/remote/aws/s3_bucket_public_access_block_enumerator.go @@ -0,0 +1,82 @@ +package aws + +import ( + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type S3BucketPublicAccessBlockEnumerator struct { + repository repository.S3Repository + factory resource.ResourceFactory + providerConfig tf.TerraformProviderConfig + alerter alerter.AlerterInterface +} + +func NewS3BucketPublicAccessBlockEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketPublicAccessBlockEnumerator { + return &S3BucketPublicAccessBlockEnumerator{ + repository: repo, + factory: factory, + providerConfig: providerConfig, + alerter: alerter, + } +} + +func (e *S3BucketPublicAccessBlockEnumerator) SupportedType() resource.ResourceType { + return aws.AwsS3BucketPublicAccessBlockResourceType +} + +func (e *S3BucketPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) { + buckets, err := e.repository.ListAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) + } + + results := make([]*resource.Resource, 0, len(buckets)) + + for _, bucket := range buckets { + region, err := e.repository.GetBucketLocation(*bucket.Name) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + if region == "" || region != e.providerConfig.DefaultAlias { + logrus.WithFields(logrus.Fields{ + "region": region, + "bucket": *bucket.Name, + }).Debug("Skipped bucket public access block") + continue + } + + block, err := e.repository.GetBucketPublicAccessBlock(*bucket.Name, region) + if err != nil { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) + continue + } + + if block != nil { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *bucket.Name, + map[string]interface{}{ + "block_public_acls": awssdk.BoolValue(block.BlockPublicAcls), + "block_public_policy": awssdk.BoolValue(block.BlockPublicPolicy), + "ignore_public_acls": awssdk.BoolValue(block.IgnorePublicAcls), + "restrict_public_buckets": awssdk.BoolValue(block.RestrictPublicBuckets), + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/aws/sns_topic_enumerator.go b/enumeration/remote/aws/sns_topic_enumerator.go new file mode 100644 index 00000000..60849617 --- /dev/null +++ b/enumeration/remote/aws/sns_topic_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type SNSTopicEnumerator struct { + repository repository.SNSRepository + factory resource.ResourceFactory +} + +func NewSNSTopicEnumerator(repo repository.SNSRepository, factory resource.ResourceFactory) *SNSTopicEnumerator { + return &SNSTopicEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *SNSTopicEnumerator) SupportedType() resource.ResourceType { + return aws.AwsSnsTopicResourceType +} + +func (e *SNSTopicEnumerator) Enumerate() ([]*resource.Resource, error) { + topics, err := e.repository.ListAllTopics() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(topics)) + + for _, topic := range topics { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *topic.TopicArn, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/sns_topic_policy_enumerator.go b/enumeration/remote/aws/sns_topic_policy_enumerator.go new file mode 100644 index 00000000..0c643383 --- /dev/null +++ b/enumeration/remote/aws/sns_topic_policy_enumerator.go @@ -0,0 +1,46 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type SNSTopicPolicyEnumerator struct { + repository repository.SNSRepository + factory resource.ResourceFactory +} + +func NewSNSTopicPolicyEnumerator(repo repository.SNSRepository, factory resource.ResourceFactory) *SNSTopicPolicyEnumerator { + return &SNSTopicPolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *SNSTopicPolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsSnsTopicPolicyResourceType +} + +func (e *SNSTopicPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + topics, err := e.repository.ListAllTopics() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsSnsTopicResourceType) + } + + results := make([]*resource.Resource, 0, len(topics)) + + for _, topic := range topics { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *topic.TopicArn, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/sns_topic_subscription_enumerator.go b/enumeration/remote/aws/sns_topic_subscription_enumerator.go new file mode 100644 index 00000000..eaa89c55 --- /dev/null +++ b/enumeration/remote/aws/sns_topic_subscription_enumerator.go @@ -0,0 +1,84 @@ +package aws + +import ( + "fmt" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type wrongArnTopicAlert struct { + arn string + endpoint *string +} + +func NewWrongArnTopicAlert(arn string, endpoint *string) *wrongArnTopicAlert { + return &wrongArnTopicAlert{arn: arn, endpoint: endpoint} +} + +func (p *wrongArnTopicAlert) Message() string { + return fmt.Sprintf("%s with incorrect subscription arn (%s) for endpoint \"%s\" will be ignored", + aws.AwsSnsTopicSubscriptionResourceType, + p.arn, + awssdk.StringValue(p.endpoint)) +} + +func (p *wrongArnTopicAlert) ShouldIgnoreResource() bool { + return false +} + +type SNSTopicSubscriptionEnumerator struct { + repository repository.SNSRepository + factory resource.ResourceFactory + alerter alerter.AlerterInterface +} + +func NewSNSTopicSubscriptionEnumerator( + repo repository.SNSRepository, + factory resource.ResourceFactory, + alerter alerter.AlerterInterface, +) *SNSTopicSubscriptionEnumerator { + return &SNSTopicSubscriptionEnumerator{ + repo, + factory, + alerter, + } +} + +func (e *SNSTopicSubscriptionEnumerator) SupportedType() resource.ResourceType { + return aws.AwsSnsTopicSubscriptionResourceType +} + +func (e *SNSTopicSubscriptionEnumerator) Enumerate() ([]*resource.Resource, error) { + allSubscriptions, err := e.repository.ListAllSubscriptions() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(allSubscriptions)) + + for _, subscription := range allSubscriptions { + if subscription.SubscriptionArn == nil || !arn.IsARN(*subscription.SubscriptionArn) { + e.alerter.SendAlert( + fmt.Sprintf("%s.%s", e.SupportedType(), *subscription.SubscriptionArn), + NewWrongArnTopicAlert(*subscription.SubscriptionArn, subscription.Endpoint), + ) + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *subscription.SubscriptionArn, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/sqs_queue_details_fetcher.go b/enumeration/remote/aws/sqs_queue_details_fetcher.go new file mode 100644 index 00000000..7aba9327 --- /dev/null +++ b/enumeration/remote/aws/sqs_queue_details_fetcher.go @@ -0,0 +1,47 @@ +package aws + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + "strings" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) + +type SQSQueueDetailsFetcher struct { + reader terraform.ResourceReader + deserializer *resource.Deserializer +} + +func NewSQSQueueDetailsFetcher(provider terraform.ResourceReader, deserializer *resource.Deserializer) *SQSQueueDetailsFetcher { + return &SQSQueueDetailsFetcher{ + reader: provider, + deserializer: deserializer, + } +} + +func (r *SQSQueueDetailsFetcher) ReadDetails(res *resource.Resource) (*resource.Resource, error) { + ctyVal, err := r.reader.ReadResource(terraform.ReadResourceArgs{ + ID: res.ResourceId(), + Ty: aws.AwsSqsQueueResourceType, + }) + if err != nil { + if strings.Contains(err.Error(), "NonExistentQueue") { + logrus.WithFields(logrus.Fields{ + "id": res.ResourceId(), + "type": aws.AwsSqsQueueResourceType, + }).Debugf("Ignoring queue that seems to be already deleted: %+v", err) + return nil, nil + } + logrus.Error(err) + return nil, remoteerror.NewResourceScanningError(err, res.ResourceType(), res.ResourceId()) + } + deserializedRes, err := r.deserializer.DeserializeOne(aws.AwsSqsQueueResourceType, *ctyVal) + if err != nil { + return nil, err + } + + return deserializedRes, nil +} diff --git a/enumeration/remote/aws/sqs_queue_enumerator.go b/enumeration/remote/aws/sqs_queue_enumerator.go new file mode 100644 index 00000000..88af3ad5 --- /dev/null +++ b/enumeration/remote/aws/sqs_queue_enumerator.go @@ -0,0 +1,48 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" + + awssdk "github.com/aws/aws-sdk-go/aws" +) + +type SQSQueueEnumerator struct { + repository repository.SQSRepository + factory resource.ResourceFactory +} + +func NewSQSQueueEnumerator(repo repository.SQSRepository, factory resource.ResourceFactory) *SQSQueueEnumerator { + return &SQSQueueEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *SQSQueueEnumerator) SupportedType() resource.ResourceType { + return aws.AwsSqsQueueResourceType +} + +func (e *SQSQueueEnumerator) Enumerate() ([]*resource.Resource, error) { + queues, err := e.repository.ListAllQueues() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(queues)) + + for _, queue := range queues { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + awssdk.StringValue(queue), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/aws/sqs_queue_policy_enumerator.go b/enumeration/remote/aws/sqs_queue_policy_enumerator.go new file mode 100644 index 00000000..3640cc8e --- /dev/null +++ b/enumeration/remote/aws/sqs_queue_policy_enumerator.go @@ -0,0 +1,69 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" + + awssdk "github.com/aws/aws-sdk-go/aws" +) + +type SQSQueuePolicyEnumerator struct { + repository repository.SQSRepository + factory resource.ResourceFactory +} + +func NewSQSQueuePolicyEnumerator(repo repository.SQSRepository, factory resource.ResourceFactory) *SQSQueuePolicyEnumerator { + return &SQSQueuePolicyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *SQSQueuePolicyEnumerator) SupportedType() resource.ResourceType { + return aws.AwsSqsQueuePolicyResourceType +} + +func (e *SQSQueuePolicyEnumerator) Enumerate() ([]*resource.Resource, error) { + queues, err := e.repository.ListAllQueues() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsSqsQueueResourceType) + } + + results := make([]*resource.Resource, 0, len(queues)) + + for _, queue := range queues { + attrs := map[string]interface{}{ + "policy": "", + } + attributes, err := e.repository.GetQueueAttributes(*queue) + if err != nil { + if strings.Contains(err.Error(), "NonExistentQueue") { + logrus.WithFields(logrus.Fields{ + "queue": *queue, + "type": aws.AwsSqsQueueResourceType, + }).Debugf("Ignoring queue that seems to be already deleted: %+v", err) + continue + } + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + if attributes.Attributes != nil { + attrs["policy"] = *attributes.Attributes[sqs.QueueAttributeNamePolicy] + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + awssdk.StringValue(queue), + attrs, + ), + ) + } + + return results, err +} diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket-bucket-martin-test-drift.res.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket-bucket-martin-test-drift.res.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket-bucket-martin-test-drift.res.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket-bucket-martin-test-drift.res.golden.json diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift_Analytics_Bucket.res.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift_Analytics_Bucket.res.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift_Analytics_Bucket.res.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift_Analytics_Bucket.res.golden.json diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_inventory-bucket-martin-test-drift_Inventory_Bucket.res.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_inventory-bucket-martin-test-drift_Inventory_Bucket.res.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_inventory-bucket-martin-test-drift_Inventory_Bucket.res.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_inventory-bucket-martin-test-drift_Inventory_Bucket.res.golden.json diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_notification-bucket-martin-test-drift.res.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_notification-bucket-martin-test-drift.res.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_notification-bucket-martin-test-drift.res.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_notification-bucket-martin-test-drift.res.golden.json diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_policy-bucket-martin-test-drift.res.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_policy-bucket-martin-test-drift.res.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_policy-bucket-martin-test-drift.res.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/aws_s3_bucket_policy-bucket-martin-test-drift.res.golden.json diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/results.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/results.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/results.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/results.golden.json diff --git a/pkg/remote/aws/test/analytics_inventory_nometrics/schema.golden.json b/enumeration/remote/aws/test/analytics_inventory_nometrics/schema.golden.json similarity index 100% rename from pkg/remote/aws/test/analytics_inventory_nometrics/schema.golden.json rename to enumeration/remote/aws/test/analytics_inventory_nometrics/schema.golden.json diff --git a/pkg/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2.res.golden.json b/enumeration/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2.res.golden.json similarity index 100% rename from pkg/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2.res.golden.json rename to enumeration/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2.res.golden.json diff --git a/pkg/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2.res.golden.json b/enumeration/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2.res.golden.json similarity index 100% rename from pkg/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2.res.golden.json rename to enumeration/remote/aws/test/s3_bucket_list/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2.res.golden.json diff --git a/pkg/remote/aws/test/s3_bucket_list/results.golden.json b/enumeration/remote/aws/test/s3_bucket_list/results.golden.json similarity index 100% rename from pkg/remote/aws/test/s3_bucket_list/results.golden.json rename to enumeration/remote/aws/test/s3_bucket_list/results.golden.json diff --git a/pkg/remote/aws/test/s3_bucket_list/schema.golden.json b/enumeration/remote/aws/test/s3_bucket_list/schema.golden.json similarity index 100% rename from pkg/remote/aws/test/s3_bucket_list/schema.golden.json rename to enumeration/remote/aws/test/s3_bucket_list/schema.golden.json diff --git a/enumeration/remote/aws/vpc_default_security_group_enumerator.go b/enumeration/remote/aws/vpc_default_security_group_enumerator.go new file mode 100644 index 00000000..1f18fb95 --- /dev/null +++ b/enumeration/remote/aws/vpc_default_security_group_enumerator.go @@ -0,0 +1,48 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/aws/aws-sdk-go/aws" +) + +type VPCDefaultSecurityGroupEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewVPCDefaultSecurityGroupEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *VPCDefaultSecurityGroupEnumerator { + return &VPCDefaultSecurityGroupEnumerator{ + repo, + factory, + } +} + +func (e *VPCDefaultSecurityGroupEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsDefaultSecurityGroupResourceType +} + +func (e *VPCDefaultSecurityGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + _, defaultSecurityGroups, err := e.repository.ListAllSecurityGroups() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(defaultSecurityGroups)) + + for _, item := range defaultSecurityGroups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + aws.StringValue(item.GroupId), + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/vpc_enumerator.go b/enumeration/remote/aws/vpc_enumerator.go new file mode 100644 index 00000000..bcac8347 --- /dev/null +++ b/enumeration/remote/aws/vpc_enumerator.go @@ -0,0 +1,47 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/snyk/driftctl/enumeration/resource" +) + +type VPCEnumerator struct { + repo repository.EC2Repository + factory resource.ResourceFactory +} + +func NewVPCEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *VPCEnumerator { + return &VPCEnumerator{ + repo, + factory, + } +} + +func (e *VPCEnumerator) SupportedType() resource.ResourceType { + return aws.AwsVpcResourceType +} + +func (e *VPCEnumerator) Enumerate() ([]*resource.Resource, error) { + VPCs, _, err := e.repo.ListAllVPCs() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(VPCs)) + + for _, item := range VPCs { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *item.VpcId, + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/vpc_security_group_enumerator.go b/enumeration/remote/aws/vpc_security_group_enumerator.go new file mode 100644 index 00000000..47d06699 --- /dev/null +++ b/enumeration/remote/aws/vpc_security_group_enumerator.go @@ -0,0 +1,48 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/aws/aws-sdk-go/aws" +) + +type VPCSecurityGroupEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +func NewVPCSecurityGroupEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *VPCSecurityGroupEnumerator { + return &VPCSecurityGroupEnumerator{ + repo, + factory, + } +} + +func (e *VPCSecurityGroupEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsSecurityGroupResourceType +} + +func (e *VPCSecurityGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + securityGroups, _, err := e.repository.ListAllSecurityGroups() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(securityGroups)) + + for _, item := range securityGroups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + aws.StringValue(item.GroupId), + map[string]interface{}{}, + ), + ) + } + + return results, nil +} diff --git a/enumeration/remote/aws/vpc_security_group_rule_enumerator.go b/enumeration/remote/aws/vpc_security_group_rule_enumerator.go new file mode 100644 index 00000000..aacc2f74 --- /dev/null +++ b/enumeration/remote/aws/vpc_security_group_rule_enumerator.go @@ -0,0 +1,169 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" +) + +const ( + sgRuleTypeIngress = "ingress" + sgRuleTypeEgress = "egress" +) + +type VPCSecurityGroupRuleEnumerator struct { + repository repository.EC2Repository + factory resource.ResourceFactory +} + +type securityGroupRule struct { + Type string + SecurityGroupId string + Protocol string + FromPort float64 + ToPort float64 + Self bool + SourceSecurityGroupId string + CidrBlocks []string + Ipv6CidrBlocks []string + PrefixListIds []string +} + +func (s *securityGroupRule) getId() string { + attrs := s.getAttrs() + return resourceaws.CreateSecurityGroupRuleIdHash(&attrs) +} + +func (s *securityGroupRule) getAttrs() resource.Attributes { + attrs := resource.Attributes{ + "type": s.Type, + "security_group_id": s.SecurityGroupId, + "protocol": s.Protocol, + "from_port": s.FromPort, + "to_port": s.ToPort, + "self": s.Self, + "source_security_group_id": s.SourceSecurityGroupId, + "cidr_blocks": toInterfaceSlice(s.CidrBlocks), + "ipv6_cidr_blocks": toInterfaceSlice(s.Ipv6CidrBlocks), + "prefix_list_ids": toInterfaceSlice(s.PrefixListIds), + } + + return attrs +} + +func toInterfaceSlice(val []string) []interface{} { + var res []interface{} + for _, v := range val { + res = append(res, v) + } + return res +} + +func NewVPCSecurityGroupRuleEnumerator(repository repository.EC2Repository, factory resource.ResourceFactory) *VPCSecurityGroupRuleEnumerator { + return &VPCSecurityGroupRuleEnumerator{ + repository, + factory, + } +} + +func (e *VPCSecurityGroupRuleEnumerator) SupportedType() resource.ResourceType { + return resourceaws.AwsSecurityGroupRuleResourceType +} + +func (e *VPCSecurityGroupRuleEnumerator) Enumerate() ([]*resource.Resource, error) { + securityGroups, defaultSecurityGroups, err := e.repository.ListAllSecurityGroups() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsSecurityGroupResourceType) + } + + secGroups := make([]*ec2.SecurityGroup, 0, len(securityGroups)+len(defaultSecurityGroups)) + secGroups = append(secGroups, securityGroups...) + secGroups = append(secGroups, defaultSecurityGroups...) + securityGroupsRules := e.listSecurityGroupsRules(secGroups) + + results := make([]*resource.Resource, 0, len(securityGroupsRules)) + for _, rule := range securityGroupsRules { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + rule.getId(), + rule.getAttrs(), + ), + ) + } + + return results, nil +} + +func (e *VPCSecurityGroupRuleEnumerator) listSecurityGroupsRules(securityGroups []*ec2.SecurityGroup) []securityGroupRule { + var securityGroupsRules []securityGroupRule + for _, sg := range securityGroups { + for _, rule := range sg.IpPermissions { + securityGroupsRules = append(securityGroupsRules, e.addSecurityGroupRule(sgRuleTypeIngress, rule, sg)...) + } + for _, rule := range sg.IpPermissionsEgress { + securityGroupsRules = append(securityGroupsRules, e.addSecurityGroupRule(sgRuleTypeEgress, rule, sg)...) + } + } + return securityGroupsRules +} + +// addSecurityGroupRule will iterate through each "Source" as per Aws definition and create a +// rule with custom attributes +func (e *VPCSecurityGroupRuleEnumerator) addSecurityGroupRule(ruleType string, rule *ec2.IpPermission, sg *ec2.SecurityGroup) []securityGroupRule { + var rules []securityGroupRule + for _, groupPair := range rule.UserIdGroupPairs { + r := securityGroupRule{ + Type: ruleType, + SecurityGroupId: aws.StringValue(sg.GroupId), + Protocol: aws.StringValue(rule.IpProtocol), + FromPort: float64(aws.Int64Value(rule.FromPort)), + ToPort: float64(aws.Int64Value(rule.ToPort)), + } + if aws.StringValue(groupPair.GroupId) == aws.StringValue(sg.GroupId) { + r.Self = true + } else { + r.SourceSecurityGroupId = aws.StringValue(groupPair.GroupId) + } + rules = append(rules, r) + } + for _, ipRange := range rule.IpRanges { + r := securityGroupRule{ + Type: ruleType, + SecurityGroupId: aws.StringValue(sg.GroupId), + Protocol: aws.StringValue(rule.IpProtocol), + FromPort: float64(aws.Int64Value(rule.FromPort)), + ToPort: float64(aws.Int64Value(rule.ToPort)), + CidrBlocks: []string{aws.StringValue(ipRange.CidrIp)}, + } + rules = append(rules, r) + } + for _, ipRange := range rule.Ipv6Ranges { + r := securityGroupRule{ + Type: ruleType, + SecurityGroupId: aws.StringValue(sg.GroupId), + Protocol: aws.StringValue(rule.IpProtocol), + FromPort: float64(aws.Int64Value(rule.FromPort)), + ToPort: float64(aws.Int64Value(rule.ToPort)), + Ipv6CidrBlocks: []string{aws.StringValue(ipRange.CidrIpv6)}, + } + rules = append(rules, r) + } + for _, listId := range rule.PrefixListIds { + r := securityGroupRule{ + Type: ruleType, + SecurityGroupId: aws.StringValue(sg.GroupId), + Protocol: aws.StringValue(rule.IpProtocol), + FromPort: float64(aws.Int64Value(rule.FromPort)), + ToPort: float64(aws.Int64Value(rule.ToPort)), + PrefixListIds: []string{aws.StringValue(listId.PrefixListId)}, + } + rules = append(rules, r) + } + return rules +} diff --git a/enumeration/remote/aws_api_gateway_scanner_test.go b/enumeration/remote/aws_api_gateway_scanner_test.go new file mode 100644 index 00000000..8be6291b --- /dev/null +++ b/enumeration/remote/aws_api_gateway_scanner_test.go @@ -0,0 +1,1727 @@ +package remote + +import ( + "testing" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/apigateway" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/enumeration/terraform" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test/remote" + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestApiGatewayRestApi(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway rest apis", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRestApis").Return([]*apigateway.RestApi{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway rest apis", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRestApis").Return([]*apigateway.RestApi{ + {Id: awssdk.String("3of73v5ob4")}, + {Id: awssdk.String("1jitcobwol")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "3of73v5ob4") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayRestApiResourceType) + + assert.Equal(t, got[1].ResourceId(), "1jitcobwol") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayRestApiResourceType) + }, + }, + { + test: "cannot list api gateway rest apis", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayRestApiResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRestApiResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayRestApiResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayRestApiEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayAccount(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway account", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("GetAccount").Return(nil, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "empty api gateway account", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("GetAccount").Return(&apigateway.Account{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "api-gateway-account") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayAccountResourceType) + }, + }, + { + test: "cannot get api gateway account", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("GetAccount").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayAccountResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAccountResourceType, resourceaws.AwsApiGatewayAccountResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayAccountResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayAccountEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayApiKey(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway api keys", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApiKeys").Return([]*apigateway.ApiKey{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway api keys", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApiKeys").Return([]*apigateway.ApiKey{ + {Id: awssdk.String("fuwnl8lrva")}, + {Id: awssdk.String("9ge737dd45")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "fuwnl8lrva") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayApiKeyResourceType) + + assert.Equal(t, got[1].ResourceId(), "9ge737dd45") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayApiKeyResourceType) + }, + }, + { + test: "cannot list api gateway api keys", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApiKeys").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayApiKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayApiKeyResourceType, resourceaws.AwsApiGatewayApiKeyResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayApiKeyResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayApiKeyEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayAuthorizer(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("3of73v5ob4")}, + {Id: awssdk.String("1jitcobwol")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway authorizers", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return([]*apigateway.Authorizer{}, nil).Once() + repo.On("ListAllRestApiAuthorizers", *apis[1].Id).Return([]*apigateway.Authorizer{}, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway authorizers", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return([]*apigateway.Authorizer{ + {Id: awssdk.String("ypcpde")}, + }, nil).Once() + repo.On("ListAllRestApiAuthorizers", *apis[1].Id).Return([]*apigateway.Authorizer{ + {Id: awssdk.String("bwhebj")}, + }, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "ypcpde") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayAuthorizerResourceType) + + assert.Equal(t, got[1].ResourceId(), "bwhebj") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayAuthorizerResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayAuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayAuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayAuthorizerResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayAuthorizerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayStage(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("3of73v5ob4")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway stages", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway stages", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{ + {StageName: awssdk.String("foo")}, + {StageName: awssdk.String("baz")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "ags-3of73v5ob4-foo") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayStageResourceType) + + assert.Equal(t, got[1].ResourceId(), "ags-3of73v5ob4-baz") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayStageResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayStageResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayStageResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayStageResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway stages", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiStages", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayStageResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayStageResourceType, resourceaws.AwsApiGatewayStageResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayStageResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayStageEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayResource(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("3of73v5ob4")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("21zk4y"), Path: awssdk.String("/")}, + {Id: awssdk.String("2ltv32p058"), Path: awssdk.String("/")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "21zk4y") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayResourceResourceType) + + assert.Equal(t, got[1].ResourceId(), "2ltv32p058") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayResourceResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayResourceResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayResourceResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayResourceResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayResourceResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayResourceResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayResourceResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayResourceEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayDomainName(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway domain names", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway domain name", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{ + {DomainName: awssdk.String("example-driftctl.com")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "example-driftctl.com") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayDomainNameResourceType) + }, + }, + { + test: "cannot list api gateway domain names", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDomainNames").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayDomainNameResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayDomainNameResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayDomainNameResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayDomainNameEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayVpcLink(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway vpc links", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVpcLinks").Return([]*apigateway.UpdateVpcLinkOutput{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway vpc link", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVpcLinks").Return([]*apigateway.UpdateVpcLinkOutput{ + {Id: awssdk.String("ipu24n")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "ipu24n") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayVpcLinkResourceType) + }, + }, + { + test: "cannot list api gateway vpc links", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVpcLinks").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayVpcLinkResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayVpcLinkResourceType, resourceaws.AwsApiGatewayVpcLinkResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayVpcLinkResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayVpcLinkEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayRequestValidator(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("vryjzimtj1")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway request validators", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiRequestValidators", *apis[0].Id).Return([]*apigateway.UpdateRequestValidatorOutput{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway request validators", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiRequestValidators", *apis[0].Id).Return([]*apigateway.UpdateRequestValidatorOutput{ + {Id: awssdk.String("ywlcuf")}, + {Id: awssdk.String("qmpbs8")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "ywlcuf") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayRequestValidatorResourceType) + + assert.Equal(t, got[1].ResourceId(), "qmpbs8") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayRequestValidatorResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayRequestValidatorResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway request validators", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiRequestValidators", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayRequestValidatorResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType, resourceaws.AwsApiGatewayRequestValidatorResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayRequestValidatorEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayRestApiPolicy(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway rest api policies", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRestApis").Return([]*apigateway.RestApi{ + {Id: awssdk.String("3of73v5ob4")}, + {Id: awssdk.String("9x7kq9pbyh"), Policy: awssdk.String("")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway rest api policies", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRestApis").Return([]*apigateway.RestApi{ + {Id: awssdk.String("c3n3aqga5d"), Policy: awssdk.String("{\"Version\":\"2012-10-17\",\"Statement\":[{\"Effect\":\"Allow\",\"Principal\":{\"AWS\":\"*\"},\"Action\":\"execute-api:Invoke\",\"Resource\":\"arn:aws:execute-api:us-east-1:111111111111:c3n3aqga5d/*\",\"Condition\":{\"IpAddress\":{\"aws:SourceIp\":\"123.123.123.123/32\"}}}]}")}, + {Id: awssdk.String("9y1eus3hr7"), Policy: awssdk.String("{\"Version\":\"2012-10-17\",\"Statement\":[{\"Effect\":\"Allow\",\"Principal\":{\"AWS\":\"*\"},\"Action\":\"execute-api:Invoke\",\"Resource\":\"arn:aws:execute-api:us-east-1:111111111111:9y1eus3hr7/*\",\"Condition\":{\"IpAddress\":{\"aws:SourceIp\":\"123.123.123.123/32\"}}}]}")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "c3n3aqga5d") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayRestApiPolicyResourceType) + + assert.Equal(t, got[1].ResourceId(), "9y1eus3hr7") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayRestApiPolicyResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayRestApiPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRestApiPolicyResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRestApiPolicyResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayRestApiPolicyEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayBasePathMapping(t *testing.T) { + dummyError := errors.New("this is an error") + domainNames := []*apigateway.DomainName{ + {DomainName: awssdk.String("example-driftctl.com")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no domain name base path mappings", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllDomainNames").Return(domainNames, nil) + repo.On("ListAllDomainNameBasePathMappings", *domainNames[0].DomainName).Return([]*apigateway.BasePathMapping{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple domain name base path mappings", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllDomainNames").Return(domainNames, nil) + repo.On("ListAllDomainNameBasePathMappings", *domainNames[0].DomainName).Return([]*apigateway.BasePathMapping{ + {BasePath: awssdk.String("foo")}, + {BasePath: awssdk.String("(none)")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "example-driftctl.com/foo") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayBasePathMappingResourceType) + + assert.Equal(t, got[1].ResourceId(), "example-driftctl.com/") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayBasePathMappingResourceType) + }, + }, + { + test: "cannot list domain names", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllDomainNames").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayBasePathMappingResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), + }, + { + test: "cannot list domain name base path mappings", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllDomainNames").Return(domainNames, nil) + repo.On("ListAllDomainNameBasePathMappings", *domainNames[0].DomainName).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayBasePathMappingResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType, resourceaws.AwsApiGatewayBasePathMappingResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayBasePathMappingEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayMethod(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("vryjzimtj1")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway methods", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway methods", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ + "GET": {}, + "POST": {}, + "DELETE": {}, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 3) + + assert.Equal(t, got[0].ResourceId(), "agm-vryjzimtj1-hl7ksq-DELETE") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayMethodResourceType) + + assert.Equal(t, got[1].ResourceId(), "agm-vryjzimtj1-hl7ksq-GET") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayMethodResourceType) + + assert.Equal(t, got[2].ResourceId(), "agm-vryjzimtj1-hl7ksq-POST") + assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayMethodResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayResourceResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayMethodEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayModel(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("vryjzimtj1")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway models", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiModels", *apis[0].Id).Return([]*apigateway.Model{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway models", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiModels", *apis[0].Id).Return([]*apigateway.Model{ + {Id: awssdk.String("g68a4s")}, + {Id: awssdk.String("85v536")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "g68a4s") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayModelResourceType) + + assert.Equal(t, got[1].ResourceId(), "85v536") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayModelResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayModelResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayModelResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayModelResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway models", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiModels", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayModelResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayModelResourceType, resourceaws.AwsApiGatewayModelResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayModelResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayModelEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayMethodResponse(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("vryjzimtj1")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway method responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ + "GET": {}, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway method responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ + "GET": {MethodResponses: map[string]*apigateway.MethodResponse{ + "200": {}, + "404": {}, + "503": {}, + }}, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 3) + + assert.Equal(t, got[0].ResourceId(), "agmr-vryjzimtj1-hl7ksq-GET-200") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayMethodResponseResourceType) + + assert.Equal(t, got[1].ResourceId(), "agmr-vryjzimtj1-hl7ksq-GET-404") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayMethodResponseResourceType) + + assert.Equal(t, got[2].ResourceId(), "agmr-vryjzimtj1-hl7ksq-GET-503") + assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayMethodResponseResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayMethodResponseEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayGatewayResponse(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("vryjzimtj1")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway gateway responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiGatewayResponses", *apis[0].Id).Return([]*apigateway.UpdateGatewayResponseOutput{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway gateway responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiGatewayResponses", *apis[0].Id).Return([]*apigateway.UpdateGatewayResponseOutput{ + {ResponseType: awssdk.String("UNAUTHORIZED")}, + {ResponseType: awssdk.String("ACCESS_DENIED")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "aggr-vryjzimtj1-UNAUTHORIZED") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayGatewayResponseResourceType) + + assert.Equal(t, got[1].ResourceId(), "aggr-vryjzimtj1-ACCESS_DENIED") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayGatewayResponseResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayGatewayResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway gateway responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiGatewayResponses", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayGatewayResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType, resourceaws.AwsApiGatewayGatewayResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayGatewayResponseEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayMethodSettings(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("vryjzimtj1")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway method settings", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{ + {StageName: awssdk.String("foo"), MethodSettings: map[string]*apigateway.MethodSetting{}}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway method settings", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{ + {StageName: awssdk.String("foo"), MethodSettings: map[string]*apigateway.MethodSetting{ + "*/*": {}, + "foo/GET": {}, + "foo/DELETE": {}, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 3) + + assert.Equal(t, got[0].ResourceId(), "vryjzimtj1-foo-*/*") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayMethodSettingsResourceType) + + assert.Equal(t, got[1].ResourceId(), "vryjzimtj1-foo-foo/DELETE") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayMethodSettingsResourceType) + + assert.Equal(t, got[2].ResourceId(), "vryjzimtj1-foo-foo/GET") + assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayMethodSettingsResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodSettingsResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway settings", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiStages", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodSettingsResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayStageResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayStageResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayMethodSettingsEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayIntegration(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("u7jce3lokk")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway integrations", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway integrations", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ + "GET": {}, + "POST": {}, + "DELETE": {}, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 3) + + assert.Equal(t, got[0].ResourceId(), "agi-u7jce3lokk-z9ag20-DELETE") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayIntegrationResourceType) + + assert.Equal(t, got[1].ResourceId(), "agi-u7jce3lokk-z9ag20-GET") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayIntegrationResourceType) + + assert.Equal(t, got[2].ResourceId(), "agi-u7jce3lokk-z9ag20-POST") + assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayIntegrationResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayResourceResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayIntegrationEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayIntegrationResponse(t *testing.T) { + dummyError := errors.New("this is an error") + apis := []*apigateway.RestApi{ + {Id: awssdk.String("u7jce3lokk")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway integration responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ + "GET": {}, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway integration responses", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ + {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ + "GET": { + MethodIntegration: &apigateway.Integration{ + IntegrationResponses: map[string]*apigateway.IntegrationResponse{ + "200": {}, + "302": {}, + }, + }, + }, + }}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "agir-u7jce3lokk-z9ag20-GET-200") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayIntegrationResponseResourceType) + + assert.Equal(t, got[1].ResourceId(), "agir-u7jce3lokk-z9ag20-GET-302") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayIntegrationResponseResourceType) + }, + }, + { + test: "cannot list rest apis", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), + }, + { + test: "cannot list api gateway resources", + mocks: func(repo *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRestApis").Return(apis, nil) + repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayIntegrationResponseEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_apigatewayv2_scanner_test.go b/enumeration/remote/aws_apigatewayv2_scanner_test.go new file mode 100644 index 00000000..564b34b5 --- /dev/null +++ b/enumeration/remote/aws_apigatewayv2_scanner_test.go @@ -0,0 +1,1231 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/apigateway" + "github.com/aws/aws-sdk-go/service/apigatewayv2" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestApiGatewayV2Api(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 api", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 api", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("f5vdrg12tk")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "f5vdrg12tk") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2ApiResourceType) + }, + }, + { + test: "cannot list api gateway v2 apis", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2ApiResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2ApiEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Route(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 api", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 api with a single route", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("a-gateway")}, + }, nil) + repository.On("ListAllApiRoutes", awssdk.String("a-gateway")). + Return([]*apigatewayv2.Route{{ + RouteId: awssdk.String("a-route"), + RouteKey: awssdk.String("POST /an-example"), + }}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, "a-route", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsApiGatewayV2RouteResourceType, got[0].ResourceType()) + expectedAttrs := &resource.Attributes{ + "api_id": "a-gateway", + "route_key": "POST /an-example", + } + assert.Equal(t, expectedAttrs, got[0].Attributes()) + }, + }, + { + test: "cannot list api gateway v2 apis", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2RouteResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2RouteEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Deployment(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "single api gateway v2 api with a single deployment", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("a-gateway")}, + }, nil) + repository.On("ListAllApiDeployments", awssdk.String("a-gateway")). + Return([]*apigatewayv2.Deployment{{ + DeploymentId: awssdk.String("a-deployment"), + }}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, "a-deployment", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsApiGatewayV2DeploymentResourceType, got[0].ResourceType()) + expectedAttrs := &resource.Attributes{} + assert.Equal(t, expectedAttrs, got[0].Attributes()) + }, + }, + { + test: "no API gateways", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single API gateway with no deployments", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("a-gateway")}, + }, nil) + repository.On("ListAllApiDeployments", awssdk.String("a-gateway")). + Return([]*apigatewayv2.Deployment{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing API gateways", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2DeploymentResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2RouteResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DeploymentResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + { + test: "error listing deployments of an API", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("a-gateway")}, + }, nil) + repository.On("ListAllApiDeployments", awssdk.String("a-gateway")).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2DeploymentResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DeploymentResourceType, resourceaws.AwsApiGatewayV2DeploymentResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DeploymentResourceType, resourceaws.AwsApiGatewayV2DeploymentResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2DeploymentEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2VpcLink(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 vpc links", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVpcLinks").Return([]*apigatewayv2.VpcLink{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 vpc link", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVpcLinks").Return([]*apigatewayv2.VpcLink{ + {VpcLinkId: awssdk.String("b8r351")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "b8r351") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2VpcLinkResourceType) + }, + }, + { + test: "cannot list api gateway v2 vpc links", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVpcLinks").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2VpcLinkResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2VpcLinkResourceType, resourceaws.AwsApiGatewayV2VpcLinkResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2VpcLinkResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2VpcLinkEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Authorizer(t *testing.T) { + dummyError := errors.New("this is an error") + + apis := []*apigatewayv2.Api{ + {ApiId: awssdk.String("bmyl5c6huh")}, + {ApiId: awssdk.String("blghshbgte")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 authorizers", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiAuthorizers", *apis[0].ApiId).Return([]*apigatewayv2.Authorizer{}, nil).Once() + repo.On("ListAllApiAuthorizers", *apis[1].ApiId).Return([]*apigatewayv2.Authorizer{}, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway v2 authorizers", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiAuthorizers", *apis[0].ApiId).Return([]*apigatewayv2.Authorizer{ + {AuthorizerId: awssdk.String("xaappu")}, + }, nil).Once() + repo.On("ListAllApiAuthorizers", *apis[1].ApiId).Return([]*apigatewayv2.Authorizer{ + {AuthorizerId: awssdk.String("bwhebj")}, + }, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "xaappu") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2AuthorizerResourceType) + + assert.Equal(t, got[1].ResourceId(), "bwhebj") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayV2AuthorizerResourceType) + }, + }, + { + test: "cannot list apis", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2AuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + { + test: "cannot list api gateway v2 authorizers", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiAuthorizers", *apis[0].ApiId).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2AuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType, resourceaws.AwsApiGatewayV2AuthorizerResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2AuthorizerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Integration(t *testing.T) { + dummyError := errors.New("this is an error") + + apis := []*apigatewayv2.Api{ + {ApiId: awssdk.String("bmyl5c6huh")}, + {ApiId: awssdk.String("blghshbgte")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 integrations", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiIntegrations", *apis[0].ApiId).Return([]*apigatewayv2.Integration{}, nil).Once() + repo.On("ListAllApiIntegrations", *apis[1].ApiId).Return([]*apigatewayv2.Integration{}, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway v2 integrations", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiIntegrations", *apis[0].ApiId).Return([]*apigatewayv2.Integration{ + { + IntegrationId: awssdk.String("xaappu"), + IntegrationType: awssdk.String("MOCK"), + }, + }, nil).Once() + repo.On("ListAllApiIntegrations", *apis[1].ApiId).Return([]*apigatewayv2.Integration{ + { + IntegrationId: awssdk.String("bwhebj"), + IntegrationType: awssdk.String("MOCK"), + }, + }, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "xaappu") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2IntegrationResourceType) + + assert.Equal(t, got[1].ResourceId(), "bwhebj") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayV2IntegrationResourceType) + }, + }, + { + test: "cannot list apis", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + { + test: "cannot list api gateway v2 integrations", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiIntegrations", *apis[0].ApiId).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2IntegrationResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2IntegrationEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Model(t *testing.T) { + dummyError := errors.New("this is an error") + + apis := []*apigatewayv2.Api{ + {ApiId: awssdk.String("bmyl5c6huh")}, + {ApiId: awssdk.String("blghshbgte")}, + } + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 models", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiModels", *apis[0].ApiId).Return([]*apigatewayv2.Model{}, nil).Once() + repo.On("ListAllApiModels", *apis[1].ApiId).Return([]*apigatewayv2.Model{}, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple api gateway v2 models", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiModels", *apis[0].ApiId).Return([]*apigatewayv2.Model{ + { + ModelId: awssdk.String("vdw6up"), + Name: awssdk.String("model1"), + }, + }, nil).Once() + repo.On("ListAllApiModels", *apis[1].ApiId).Return([]*apigatewayv2.Model{ + { + ModelId: awssdk.String("bwhebj"), + Name: awssdk.String("model2"), + }, + }, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "vdw6up") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2ModelResourceType) + assert.Equal(t, "model1", *got[0].Attributes().GetString("name")) + + assert.Equal(t, got[1].ResourceId(), "bwhebj") + assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayV2ModelResourceType) + assert.Equal(t, "model2", *got[1].Attributes().GetString("name")) + + }, + }, + { + test: "cannot list apis", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2ModelResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + { + test: "cannot list api gateway v2 model", + mocks: func(repo *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repo.On("ListAllApis").Return(apis, nil) + repo.On("ListAllApiModels", *apis[0].ApiId).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2ModelResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType, resourceaws.AwsApiGatewayV2ModelResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2ModelEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Stage(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 api", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 api with a single stage", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("a-gateway")}, + }, nil) + repository.On("ListAllApiStages", "a-gateway"). + Return([]*apigatewayv2.Stage{{ + StageName: awssdk.String("a-stage"), + }}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, "a-stage", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsApiGatewayV2StageResourceType, got[0].ResourceType()) + }, + }, + { + test: "cannot list api gateway v2 apis", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2StageResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2StageResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2StageResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2StageEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2RouteResponse(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 route responses", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("77ooqulkke")}, + }, nil) + repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")). + Return([]*apigatewayv2.Route{ + {RouteId: awssdk.String("liqc5u4")}, + }, nil) + repository.On("ListAllApiRouteResponses", "77ooqulkke", "liqc5u4"). + Return([]*apigatewayv2.RouteResponse{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 route with one route response", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("77ooqulkke")}, + }, nil) + repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")). + Return([]*apigatewayv2.Route{ + {RouteId: awssdk.String("liqc5u4")}, + }, nil) + repository.On("ListAllApiRouteResponses", "77ooqulkke", "liqc5u4"). + Return([]*apigatewayv2.RouteResponse{ + {RouteResponseId: awssdk.String("nbw7vw")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "nbw7vw") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2RouteResponseResourceType) + }, + }, + { + test: "cannot list api gateway v2 apis", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + { + test: "cannot list api gateway v2 routes", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("77ooqulkke")}, + }, nil) + repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")).Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2RouteResourceType), + }, + { + test: "cannot list api gateway v2 route responses", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("77ooqulkke")}, + }, nil) + repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")). + Return([]*apigatewayv2.Route{ + {RouteId: awssdk.String("liqc5u4")}, + }, nil) + repository.On("ListAllApiRouteResponses", "77ooqulkke", "liqc5u4").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2RouteResponseEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2Mapping(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 domains", + mocks: func(repositoryV1 *repository2.MockApiGatewayRepository, repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 domain with a single mapping", + mocks: func(repositoryV1 *repository2.MockApiGatewayRepository, repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{ + {DomainName: awssdk.String("example.com")}, + }, nil) + repository.On("ListAllApiMappings", "example.com"). + Return([]*apigatewayv2.ApiMapping{{ + Stage: awssdk.String("a-stage"), + ApiId: awssdk.String("foobar"), + ApiMappingId: awssdk.String("barfoo"), + }}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, "barfoo", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsApiGatewayV2MappingResourceType, got[0].ResourceType()) + }, + }, + { + test: "cannot list api gateway v2 domains", + mocks: func(repositoryV1 *repository2.MockApiGatewayRepository, repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repositoryV1.On("ListAllDomainNames").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2MappingResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayDomainNameResourceType, resourceaws.AwsApiGatewayV2MappingResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2MappingResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), + }, + { + test: "cannot list api gateway v2 mappings", + mocks: func(repositoryV1 *repository2.MockApiGatewayRepository, repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{ + {DomainName: awssdk.String("example.com")}, + }, nil) + repository.On("ListAllApiMappings", "example.com"). + Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2MappingResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2MappingResourceType, resourceaws.AwsApiGatewayV2MappingResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2MappingResourceType), + }, + { + test: "returning mapping with invalid attributes", + mocks: func(repositoryV1 *repository2.MockApiGatewayRepository, repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{ + {DomainName: awssdk.String("example.com")}, + }, nil) + repository.On("ListAllApiMappings", "example.com"). + Return([]*apigatewayv2.ApiMapping{ + { + ApiMappingId: awssdk.String("barfoo"), + }, + { + Stage: awssdk.String("a-stage"), + ApiId: awssdk.String("foobar"), + ApiMappingId: awssdk.String("foobar"), + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, "barfoo", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsApiGatewayV2MappingResourceType, got[0].ResourceType()) + + assert.Equal(t, "foobar", got[1].ResourceId()) + assert.Equal(t, resourceaws.AwsApiGatewayV2MappingResourceType, got[1].ResourceType()) + }, + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepoV1 := &repository2.MockApiGatewayRepository{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepoV1, fakeRepo, alerter) + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2MappingEnumerator(fakeRepo, fakeRepoV1, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + fakeRepoV1.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2DomainName(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 domain names", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 domain name", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{ + {DomainName: awssdk.String("b8r351.example.com")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "b8r351.example.com") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2DomainNameResourceType) + }, + }, + { + test: "cannot list api gateway v2 domain names", + mocks: func(repository *repository2.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDomainNames").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2DomainNameResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DomainNameResourceType, resourceaws.AwsApiGatewayV2DomainNameResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2DomainNameResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2DomainNameEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestApiGatewayV2IntegrationResponse(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockApiGatewayV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no api gateway v2 integration responses", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("yw28nwdf34")}, + }, nil) + repository.On("ListAllApiIntegrations", "yw28nwdf34"). + Return([]*apigatewayv2.Integration{ + {IntegrationId: awssdk.String("fmezvlh")}, + }, nil) + repository.On("ListAllApiIntegrationResponses", "yw28nwdf34", "fmezvlh"). + Return([]*apigatewayv2.IntegrationResponse{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "single api gateway v2 integration with one integration response", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("yw28nwdf34")}, + }, nil) + repository.On("ListAllApiIntegrations", "yw28nwdf34"). + Return([]*apigatewayv2.Integration{ + {IntegrationId: awssdk.String("fmezvlh")}, + }, nil) + repository.On("ListAllApiIntegrationResponses", "yw28nwdf34", "fmezvlh"). + Return([]*apigatewayv2.IntegrationResponse{ + {IntegrationResponseId: awssdk.String("sf67ti7")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "sf67ti7") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2IntegrationResponseResourceType) + }, + }, + { + test: "cannot list api gateway v2 apis", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), + }, + { + test: "cannot list api gateway v2 integrations", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("yw28nwdf34")}, + }, nil) + repository.On("ListAllApiIntegrations", "yw28nwdf34").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2IntegrationResourceType), + }, + { + test: "cannot list api gateway v2 integration responses", + mocks: func(repository *repository2.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllApis").Return([]*apigatewayv2.Api{ + {ApiId: awssdk.String("yw28nwdf34")}, + }, nil) + repository.On("ListAllApiIntegrations", "yw28nwdf34"). + Return([]*apigatewayv2.Integration{ + {IntegrationId: awssdk.String("fmezvlh")}, + }, nil) + repository.On("ListAllApiIntegrationResponses", "yw28nwdf34", "fmezvlh").Return(nil, dummyError) + alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockApiGatewayV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ApiGatewayV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewApiGatewayV2IntegrationResponseEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_applicationautoscaling_scanner_test.go b/enumeration/remote/aws_applicationautoscaling_scanner_test.go new file mode 100644 index 00000000..75ff488e --- /dev/null +++ b/enumeration/remote/aws_applicationautoscaling_scanner_test.go @@ -0,0 +1,328 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/applicationautoscaling" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAppAutoScalingTarget(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockAppAutoScalingRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "should return one target", + dirName: "aws_appautoscaling_target_single", + mocks: func(client *repository2.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { + client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() + + client.On("DescribeScalableTargets", "dynamodb").Return([]*applicationautoscaling.ScalableTarget{ + { + ResourceId: awssdk.String("table/GameScores"), + RoleARN: awssdk.String("arn:aws:iam::533948124879:role/aws-service-role/dynamodb.application-autoscaling.amazonaws.com/AWSServiceRoleForApplicationAutoScaling_DynamoDBTable"), + ScalableDimension: awssdk.String("dynamodb:table:ReadCapacityUnits"), + ServiceNamespace: awssdk.String("dynamodb"), + MaxCapacity: awssdk.Int64(100), + MinCapacity: awssdk.Int64(5), + }, + }, nil).Once() + + client.On("DescribeScalableTargets", mock.AnythingOfType("string")).Return([]*applicationautoscaling.ScalableTarget{}, nil).Times(len(applicationautoscaling.ServiceNamespace_Values()) - 1) + }, + wantErr: nil, + }, + { + test: "should return remote error", + dirName: "aws_appautoscaling_target_single", + mocks: func(client *repository2.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { + client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() + + client.On("DescribeScalableTargets", mock.AnythingOfType("string")).Return(nil, errors.New("remote error")).Once() + }, + wantErr: remoteerror.NewResourceListingError(errors.New("remote error"), resourceaws.AwsAppAutoscalingTargetResourceType), + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockAppAutoScalingRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.AppAutoScalingRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewAppAutoScalingRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewAppAutoscalingTargetEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsAppAutoscalingTargetResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsAppAutoscalingTargetResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + if err != nil { + assert.EqualError(tt, c.wantErr, err.Error()) + } else { + assert.Equal(tt, err, c.wantErr) + } + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsAppAutoscalingTargetResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAppAutoScalingPolicy(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockAppAutoScalingRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "should return one policy", + dirName: "aws_appautoscaling_policy_single", + mocks: func(client *repository2.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { + client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() + + client.On("DescribeScalingPolicies", "dynamodb").Return([]*applicationautoscaling.ScalingPolicy{ + { + PolicyName: awssdk.String("DynamoDBReadCapacityUtilization:table/GameScores"), + ResourceId: awssdk.String("table/GameScores"), + ScalableDimension: awssdk.String("dynamodb:table:ReadCapacityUnits"), + ServiceNamespace: awssdk.String("dynamodb"), + }, + }, nil).Once() + + client.On("DescribeScalingPolicies", mock.AnythingOfType("string")).Return([]*applicationautoscaling.ScalingPolicy{}, nil).Times(len(applicationautoscaling.ServiceNamespace_Values()) - 1) + }, + wantErr: nil, + }, + { + test: "should return remote error", + dirName: "aws_appautoscaling_policy_single", + mocks: func(client *repository2.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { + client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() + + client.On("DescribeScalingPolicies", mock.AnythingOfType("string")).Return(nil, errors.New("remote error")).Once() + }, + wantErr: remoteerror.NewResourceListingError(errors.New("remote error"), resourceaws.AwsAppAutoscalingPolicyResourceType), + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockAppAutoScalingRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.AppAutoScalingRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewAppAutoScalingRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewAppAutoscalingPolicyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsAppAutoscalingPolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsAppAutoscalingPolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + if err != nil { + assert.EqualError(tt, c.wantErr, err.Error()) + } else { + assert.Equal(tt, err, c.wantErr) + } + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsAppAutoscalingPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAppAutoScalingScheduledAction(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockAppAutoScalingRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "should return one scheduled action", + mocks: func(client *repository2.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { + matchServiceNamespaceFunc := func(ns string) bool { + for _, n := range applicationautoscaling.ServiceNamespace_Values() { + if n == ns { + return true + } + } + return false + } + + client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() + + client.On("DescribeScheduledActions", mock.MatchedBy(matchServiceNamespaceFunc)).Return([]*applicationautoscaling.ScheduledAction{ + { + ScheduledActionName: awssdk.String("action"), + ResourceId: awssdk.String("table/GameScores"), + ScalableDimension: awssdk.String("dynamodb:table:ReadCapacityUnits"), + ServiceNamespace: awssdk.String("dynamodb"), + }, + }, nil).Once() + + client.On("DescribeScheduledActions", mock.MatchedBy(matchServiceNamespaceFunc)).Return([]*applicationautoscaling.ScheduledAction{}, nil).Times(len(applicationautoscaling.ServiceNamespace_Values()) - 1) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "action-dynamodb-table/GameScores", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsAppAutoscalingScheduledActionResourceType, got[0].ResourceType()) + }, + wantErr: nil, + }, + { + test: "should return remote error", + mocks: func(client *repository2.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { + client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() + + client.On("DescribeScheduledActions", mock.AnythingOfType("string")).Return(nil, dummyError).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + wantErr: remoteerror.NewResourceListingError(dummyError, resourceaws.AwsAppAutoscalingScheduledActionResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockAppAutoScalingRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.AppAutoScalingRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewAppAutoscalingScheduledActionEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_autoscaling_scanner_test.go b/enumeration/remote/aws_autoscaling_scanner_test.go new file mode 100644 index 00000000..acb41155 --- /dev/null +++ b/enumeration/remote/aws_autoscaling_scanner_test.go @@ -0,0 +1,110 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/autoscaling" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAutoscaling_LaunchConfiguration(t *testing.T) { + tests := []struct { + test string + mocks func(*repository2.MockAutoScalingRepository, *mocks.AlerterInterface) + assertExpected func(*testing.T, []*resource.Resource) + wantErr error + }{ + { + test: "no launch configuration", + mocks: func(repository *repository2.MockAutoScalingRepository, alerter *mocks.AlerterInterface) { + repository.On("DescribeLaunchConfigurations").Return([]*autoscaling.LaunchConfiguration{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple launch configurations", + mocks: func(repository *repository2.MockAutoScalingRepository, alerter *mocks.AlerterInterface) { + repository.On("DescribeLaunchConfigurations").Return([]*autoscaling.LaunchConfiguration{ + {LaunchConfigurationName: awssdk.String("web_config_1")}, + {LaunchConfigurationName: awssdk.String("web_config_2")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, "web_config_1", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsLaunchConfigurationResourceType, got[0].ResourceType()) + + assert.Equal(t, "web_config_2", got[1].ResourceId()) + assert.Equal(t, resourceaws.AwsLaunchConfigurationResourceType, got[1].ResourceType()) + }, + }, + { + test: "cannot list launch configurations", + mocks: func(repository *repository2.MockAutoScalingRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("DescribeLaunchConfigurations").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsLaunchConfigurationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLaunchConfigurationResourceType, resourceaws.AwsLaunchConfigurationResourceType), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockAutoScalingRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.AutoScalingRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws.NewLaunchConfigurationEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_cloudformation_scanner_test.go b/enumeration/remote/aws_cloudformation_scanner_test.go new file mode 100644 index 00000000..c4bfba87 --- /dev/null +++ b/enumeration/remote/aws_cloudformation_scanner_test.go @@ -0,0 +1,128 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudformation" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestCloudformationStack(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockCloudformationRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no cloudformation stacks", + dirName: "aws_cloudformation_stack_empty", + mocks: func(repository *repository2.MockCloudformationRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStacks").Return([]*cloudformation.Stack{}, nil) + }, + }, + { + test: "multiple cloudformation stacks", + dirName: "aws_cloudformation_stack_multiple", + mocks: func(repository *repository2.MockCloudformationRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStacks").Return([]*cloudformation.Stack{ + {StackId: awssdk.String("arn:aws:cloudformation:us-east-1:047081014315:stack/bar-stack/c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab")}, + {StackId: awssdk.String("arn:aws:cloudformation:us-east-1:047081014315:stack/foo-stack/c7aa0ab0-0f21-11ec-ba25-129d8c0b3757")}, + }, nil) + }, + }, + { + test: "cannot list cloudformation stacks", + dirName: "aws_cloudformation_stack_list", + mocks: func(repository *repository2.MockCloudformationRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") + repository.On("ListAllStacks").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsCloudformationStackResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsCloudformationStackResourceType, resourceaws.AwsCloudformationStackResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockCloudformationRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.CloudformationRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewCloudformationRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws.NewCloudformationStackEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsCloudformationStackResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsCloudformationStackResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsCloudformationStackResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_cloudfront_scanner_test.go b/enumeration/remote/aws_cloudfront_scanner_test.go new file mode 100644 index 00000000..39d8ac54 --- /dev/null +++ b/enumeration/remote/aws_cloudfront_scanner_test.go @@ -0,0 +1,126 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudfront" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestCloudfrontDistribution(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockCloudfrontRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no cloudfront distributions", + dirName: "aws_cloudfront_distribution_empty", + mocks: func(repository *repository2.MockCloudfrontRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDistributions").Return([]*cloudfront.DistributionSummary{}, nil) + }, + }, + { + test: "single cloudfront distribution", + dirName: "aws_cloudfront_distribution_single", + mocks: func(repository *repository2.MockCloudfrontRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDistributions").Return([]*cloudfront.DistributionSummary{ + {Id: awssdk.String("E1M9CNS0XSHI19")}, + }, nil) + }, + }, + { + test: "cannot list cloudfront distributions", + dirName: "aws_cloudfront_distribution_list", + mocks: func(repository *repository2.MockCloudfrontRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") + repository.On("ListAllDistributions").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsCloudfrontDistributionResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsCloudfrontDistributionResourceType, resourceaws.AwsCloudfrontDistributionResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockCloudfrontRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.CloudfrontRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewCloudfrontRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws.NewCloudfrontDistributionEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsCloudfrontDistributionResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsCloudfrontDistributionResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsCloudfrontDistributionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_dynamodb_scanner_test.go b/enumeration/remote/aws_dynamodb_scanner_test.go new file mode 100644 index 00000000..7c21587c --- /dev/null +++ b/enumeration/remote/aws_dynamodb_scanner_test.go @@ -0,0 +1,128 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestDynamoDBTable(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockDynamoDBRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no DynamoDB Table", + dirName: "aws_dynamodb_table_empty", + mocks: func(client *repository2.MockDynamoDBRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllTables").Return([]*string{}, nil) + }, + wantErr: nil, + }, + { + test: "Multiple DynamoDB Table", + dirName: "aws_dynamodb_table_multiple", + mocks: func(client *repository2.MockDynamoDBRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllTables").Return([]*string{ + awssdk.String("GameScores"), + awssdk.String("example"), + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list DynamoDB Table", + dirName: "aws_dynamodb_table_list", + mocks: func(client *repository2.MockDynamoDBRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") + client.On("ListAllTables").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDynamodbTableResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDynamodbTableResourceType, resourceaws.AwsDynamodbTableResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockDynamoDBRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.DynamoDBRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewDynamoDBRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws.NewDynamoDBTableEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDynamodbTableResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDynamodbTableResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDynamodbTableResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_ec2_scanner_test.go b/enumeration/remote/aws_ec2_scanner_test.go new file mode 100644 index 00000000..9ba6b9a4 --- /dev/null +++ b/enumeration/remote/aws_ec2_scanner_test.go @@ -0,0 +1,2907 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestEC2EbsVolume(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no volumes", + dirName: "aws_ec2_ebs_volume_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVolumes").Return([]*ec2.Volume{}, nil) + }, + }, + { + test: "multiple volumes", + dirName: "aws_ec2_ebs_volume_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVolumes").Return([]*ec2.Volume{ + {VolumeId: awssdk.String("vol-081c7272a57a09db1")}, + {VolumeId: awssdk.String("vol-01ddc91d3d9d1318b")}, + }, nil) + }, + }, + { + test: "cannot list volumes", + dirName: "aws_ec2_ebs_volume_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllVolumes").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsEbsVolumeResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEbsVolumeResourceType, resourceaws.AwsEbsVolumeResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2EbsVolumeEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsEbsVolumeResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsEbsVolumeResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsEbsVolumeResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2EbsSnapshot(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no snapshots", + dirName: "aws_ec2_ebs_snapshot_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSnapshots").Return([]*ec2.Snapshot{}, nil) + }, + }, + { + test: "multiple snapshots", + dirName: "aws_ec2_ebs_snapshot_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSnapshots").Return([]*ec2.Snapshot{ + {SnapshotId: awssdk.String("snap-0c509a2a880d95a39")}, + {SnapshotId: awssdk.String("snap-00672558cecd93a61")}, + }, nil) + }, + }, + { + test: "cannot list snapshots", + dirName: "aws_ec2_ebs_snapshot_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllSnapshots").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsEbsSnapshotResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEbsSnapshotResourceType, resourceaws.AwsEbsSnapshotResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2EbsSnapshotEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsEbsSnapshotResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsEbsSnapshotResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsEbsSnapshotResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2Eip(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no eips", + dirName: "aws_ec2_eip_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllAddresses").Return([]*ec2.Address{ + {}, // Test Eip without AllocationId because it can happen (seen in sentry) + }, nil) + }, + }, + { + test: "multiple eips", + dirName: "aws_ec2_eip_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllAddresses").Return([]*ec2.Address{ + {AllocationId: awssdk.String("eipalloc-017d5267e4dda73f1")}, + {AllocationId: awssdk.String("eipalloc-0cf714dc097c992cc")}, + }, nil) + }, + }, + { + test: "cannot list eips", + dirName: "aws_ec2_eip_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllAddresses").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsEipResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEipResourceType, resourceaws.AwsEipResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2EipEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsEipResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsEipResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsEipResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2Ami(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no amis", + dirName: "aws_ec2_ami_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllImages").Return([]*ec2.Image{}, nil) + }, + }, + { + test: "multiple amis", + dirName: "aws_ec2_ami_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllImages").Return([]*ec2.Image{ + {ImageId: awssdk.String("ami-03a578b46f4c3081b")}, + {ImageId: awssdk.String("ami-025962fd8b456731f")}, + }, nil) + }, + }, + { + test: "cannot list ami", + dirName: "aws_ec2_ami_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllImages").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsAmiResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsAmiResourceType, resourceaws.AwsAmiResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2AmiEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsAmiResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsAmiResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsAmiResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2KeyPair(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no key pairs", + dirName: "aws_ec2_key_pair_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllKeyPairs").Return([]*ec2.KeyPairInfo{}, nil) + }, + }, + { + test: "multiple key pairs", + dirName: "aws_ec2_key_pair_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllKeyPairs").Return([]*ec2.KeyPairInfo{ + {KeyName: awssdk.String("test")}, + {KeyName: awssdk.String("bar")}, + }, nil) + }, + }, + { + test: "cannot list key pairs", + dirName: "aws_ec2_key_pair_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllKeyPairs").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsKeyPairResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsKeyPairResourceType, resourceaws.AwsKeyPairResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2KeyPairEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsKeyPairResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsKeyPairResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsKeyPairResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2EipAssociation(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no eip associations", + dirName: "aws_ec2_eip_association_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{}, nil) + }, + }, + { + test: "single eip association", + dirName: "aws_ec2_eip_association_single", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{ + { + AssociationId: awssdk.String("eipassoc-0e9a7356e30f0c3d1"), + AllocationId: awssdk.String("eipalloc-017d5267e4dda73f1"), + }, + }, nil) + }, + }, + { + test: "cannot list eip associations", + dirName: "aws_ec2_eip_association_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllAddressesAssociation").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsEipAssociationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEipAssociationResourceType, resourceaws.AwsEipAssociationResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2EipAssociationEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsEipAssociationResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsEipAssociationResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsEipAssociationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2Instance(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no instances", + dirName: "aws_ec2_instance_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllInstances").Return([]*ec2.Instance{}, nil) + }, + }, + { + test: "multiple instances", + dirName: "aws_ec2_instance_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllInstances").Return([]*ec2.Instance{ + {InstanceId: awssdk.String("i-0d3650a23f4e45dc0")}, + {InstanceId: awssdk.String("i-010376047a71419f1")}, + }, nil) + }, + }, + { + test: "terminated instances", + dirName: "aws_ec2_instance_terminated", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllInstances").Return([]*ec2.Instance{ + {InstanceId: awssdk.String("i-0e1543baf4f2cd990")}, + {InstanceId: awssdk.String("i-0a3a7ed51ae2b4fa0")}, // Nil + }, nil) + }, + }, + { + test: "cannot list instances", + dirName: "aws_ec2_instance_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllInstances").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsInstanceResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsInstanceResourceType, resourceaws.AwsInstanceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2InstanceEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsInstanceResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsInstanceResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsInstanceResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2InternetGateway(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no internet gateways", + dirName: "aws_ec2_internet_gateway_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllInternetGateways").Return([]*ec2.InternetGateway{}, nil) + }, + }, + { + test: "multiple internet gateways", + dirName: "aws_ec2_internet_gateway_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllInternetGateways").Return([]*ec2.InternetGateway{ + {InternetGatewayId: awssdk.String("igw-0184eb41aadc62d1c")}, + {InternetGatewayId: awssdk.String("igw-047b487f5c60fca99")}, + }, nil) + }, + }, + { + test: "cannot list internet gateways", + dirName: "aws_ec2_internet_gateway_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllInternetGateways").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsInternetGatewayResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsInternetGatewayResourceType, resourceaws.AwsInternetGatewayResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2InternetGatewayEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsInternetGatewayResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsInternetGatewayResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsInternetGatewayResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestVPC(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no VPC", + dirName: "aws_vpc_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{}, []*ec2.Vpc{}, nil) + }, + wantErr: nil, + }, + { + test: "VPC results", + dirName: "aws_vpc", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{ + { + VpcId: awssdk.String("vpc-0768e1fd0029e3fc3"), + }, + { + VpcId: awssdk.String("vpc-020b072316a95b97f"), + IsDefault: awssdk.Bool(false), + }, + { + VpcId: awssdk.String("vpc-02c50896b59598761"), + IsDefault: awssdk.Bool(false), + }, + }, []*ec2.Vpc{ + { + VpcId: awssdk.String("vpc-a8c5d4c1"), + IsDefault: awssdk.Bool(false), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list VPC", + dirName: "aws_vpc_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllVPCs").Once().Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsVpcResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsVpcResourceType, resourceaws.AwsVpcResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewVPCEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsVpcResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsVpcResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsVpcResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestDefaultVPC(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no VPC", + dirName: "aws_vpc_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{}, []*ec2.Vpc{}, nil) + }, + wantErr: nil, + }, + { + test: "default VPC results", + dirName: "aws_default_vpc", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{ + { + VpcId: awssdk.String("vpc-0768e1fd0029e3fc3"), + IsDefault: awssdk.Bool(false), + }, + { + VpcId: awssdk.String("vpc-020b072316a95b97f"), + IsDefault: awssdk.Bool(false), + }, + }, []*ec2.Vpc{ + { + VpcId: awssdk.String("vpc-a8c5d4c1"), + IsDefault: awssdk.Bool(true), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list VPC", + dirName: "aws_vpc_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllVPCs").Once().Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDefaultVpcResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultVpcResourceType, resourceaws.AwsDefaultVpcResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewDefaultVPCEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultVpcResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDefaultVpcResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultVpcResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2RouteTableAssociation(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no route table associations (test for nil values)", + dirName: "aws_ec2_route_table_association_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ + { + RouteTableId: awssdk.String("assoc_with_nil"), + Associations: []*ec2.RouteTableAssociation{ + { + AssociationState: nil, + GatewayId: nil, + Main: nil, + RouteTableAssociationId: nil, + RouteTableId: nil, + SubnetId: nil, + }, + }, + }, + {RouteTableId: awssdk.String("nil_assoc")}, + }, nil) + }, + }, + { + test: "multiple route table associations (mixed subnet and gateway associations)", + dirName: "aws_ec2_route_table_association_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ + { + RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), // route + Associations: []*ec2.RouteTableAssociation{ + { // Should be ignored + AssociationState: &ec2.RouteTableAssociationState{ + State: awssdk.String("disassociated"), + }, + GatewayId: awssdk.String("dummy-id"), + }, + { // Should be ignored + SubnetId: nil, + GatewayId: nil, + }, + { // assoc_route_subnet1 + AssociationState: &ec2.RouteTableAssociationState{ + State: awssdk.String("associated"), + }, + Main: awssdk.Bool(false), + RouteTableAssociationId: awssdk.String("rtbassoc-0809598f92dbec03b"), + RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), + SubnetId: awssdk.String("subnet-05185af647b2eeda3"), + }, + { // assoc_route_subnet + AssociationState: &ec2.RouteTableAssociationState{ + State: awssdk.String("associated"), + }, + Main: awssdk.Bool(false), + RouteTableAssociationId: awssdk.String("rtbassoc-01957791b2cfe6ea4"), + RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), + SubnetId: awssdk.String("subnet-0e93dbfa2e5dd8282"), + }, + { // assoc_route_subnet2 + AssociationState: &ec2.RouteTableAssociationState{ + State: awssdk.String("associated"), + }, + GatewayId: nil, + Main: awssdk.Bool(false), + RouteTableAssociationId: awssdk.String("rtbassoc-0b4f97ea57490e213"), + RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), + SubnetId: awssdk.String("subnet-0fd966efd884d0362"), + }, + }, + }, + { + RouteTableId: awssdk.String("rtb-09df7cc9d16de9f8f"), // route2 + Associations: []*ec2.RouteTableAssociation{ + { // assoc_route2_gateway + AssociationState: &ec2.RouteTableAssociationState{ + State: awssdk.String("associated"), + }, + RouteTableAssociationId: awssdk.String("rtbassoc-0a79ccacfceb4944b"), + RouteTableId: awssdk.String("rtb-09df7cc9d16de9f8f"), + GatewayId: awssdk.String("igw-0238f6e09185ac954"), + }, + }, + }, + }, nil) + }, + }, + { + test: "cannot list route table associations", + dirName: "aws_ec2_route_table_association_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllRouteTables").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsRouteTableAssociationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRouteTableAssociationResourceType, resourceaws.AwsRouteTableResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2RouteTableAssociationEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRouteTableAssociationResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRouteTableAssociationResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRouteTableAssociationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2Subnet(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no subnets", + dirName: "aws_ec2_subnet_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSubnets").Return([]*ec2.Subnet{}, []*ec2.Subnet{}, nil) + }, + }, + { + test: "multiple subnets", + dirName: "aws_ec2_subnet_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSubnets").Return([]*ec2.Subnet{ + { + SubnetId: awssdk.String("subnet-05810d3f933925f6d"), // subnet1 + DefaultForAz: awssdk.Bool(false), + }, + { + SubnetId: awssdk.String("subnet-0b13f1e0eacf67424"), // subnet2 + DefaultForAz: awssdk.Bool(false), + }, + { + SubnetId: awssdk.String("subnet-0c9b78001fe186e22"), // subnet3 + DefaultForAz: awssdk.Bool(false), + }, + }, []*ec2.Subnet{ + { + SubnetId: awssdk.String("subnet-44fe0c65"), // us-east-1a + DefaultForAz: awssdk.Bool(true), + }, + { + SubnetId: awssdk.String("subnet-65e16628"), // us-east-1b + DefaultForAz: awssdk.Bool(true), + }, + { + SubnetId: awssdk.String("subnet-afa656f0"), // us-east-1c + DefaultForAz: awssdk.Bool(true), + }, + }, nil) + }, + }, + { + test: "cannot list subnets", + dirName: "aws_ec2_subnet_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllSubnets").Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSubnetResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSubnetResourceType, resourceaws.AwsSubnetResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2SubnetEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSubnetResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSubnetResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSubnetResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2DefaultSubnet(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no default subnets", + dirName: "aws_ec2_default_subnet_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSubnets").Return([]*ec2.Subnet{}, []*ec2.Subnet{}, nil) + }, + }, + { + test: "multiple default subnets", + dirName: "aws_ec2_default_subnet_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSubnets").Return([]*ec2.Subnet{ + { + SubnetId: awssdk.String("subnet-05810d3f933925f6d"), // subnet1 + DefaultForAz: awssdk.Bool(false), + }, + { + SubnetId: awssdk.String("subnet-0b13f1e0eacf67424"), // subnet2 + DefaultForAz: awssdk.Bool(false), + }, + { + SubnetId: awssdk.String("subnet-0c9b78001fe186e22"), // subnet3 + DefaultForAz: awssdk.Bool(false), + }, + }, []*ec2.Subnet{ + { + SubnetId: awssdk.String("subnet-44fe0c65"), // us-east-1a + DefaultForAz: awssdk.Bool(true), + }, + { + SubnetId: awssdk.String("subnet-65e16628"), // us-east-1b + DefaultForAz: awssdk.Bool(true), + }, + { + SubnetId: awssdk.String("subnet-afa656f0"), // us-east-1c + DefaultForAz: awssdk.Bool(true), + }, + }, nil) + }, + }, + { + test: "cannot list default subnets", + dirName: "aws_ec2_default_subnet_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllSubnets").Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDefaultSubnetResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultSubnetResourceType, resourceaws.AwsDefaultSubnetResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2DefaultSubnetEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultSubnetResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDefaultSubnetResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultSubnetResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2RouteTable(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no route tables", + dirName: "aws_ec2_route_table_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{}, nil) + }, + }, + { + test: "multiple route tables", + dirName: "aws_ec2_route_table_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ + {RouteTableId: awssdk.String("rtb-08b7b71af15e183ce")}, // table1 + {RouteTableId: awssdk.String("rtb-0002ac731f6fdea55")}, // table2 + {RouteTableId: awssdk.String("rtb-0c55d55593f33fbac")}, // table3 + { + RouteTableId: awssdk.String("rtb-0eabf071c709c0976"), // default_table + VpcId: awssdk.String("vpc-0b4a6b3536da20ecd"), + Associations: []*ec2.RouteTableAssociation{ + { + Main: awssdk.Bool(true), + }, + }, + }, + }, nil) + }, + }, + { + test: "cannot list route tables", + dirName: "aws_ec2_route_table_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllRouteTables").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsRouteTableResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRouteTableResourceType, resourceaws.AwsRouteTableResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2RouteTableEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRouteTableResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRouteTableResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRouteTableResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2DefaultRouteTable(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no default route tables", + dirName: "aws_ec2_default_route_table_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{}, nil) + }, + }, + { + test: "multiple default route tables", + dirName: "aws_ec2_default_route_table_single", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ + {RouteTableId: awssdk.String("rtb-08b7b71af15e183ce")}, // table1 + {RouteTableId: awssdk.String("rtb-0002ac731f6fdea55")}, // table2 + {RouteTableId: awssdk.String("rtb-0c55d55593f33fbac")}, // table3 + { + RouteTableId: awssdk.String("rtb-0eabf071c709c0976"), // default_table + VpcId: awssdk.String("vpc-0b4a6b3536da20ecd"), + Associations: []*ec2.RouteTableAssociation{ + { + Main: awssdk.Bool(true), + }, + }, + }, + }, nil) + }, + }, + { + test: "cannot list default route tables", + dirName: "aws_ec2_default_route_table_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllRouteTables").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDefaultRouteTableResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultRouteTableResourceType, resourceaws.AwsDefaultRouteTableResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2DefaultRouteTableEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultRouteTableResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDefaultRouteTableResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultRouteTableResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestVpcSecurityGroup(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no security groups", + dirName: "aws_vpc_security_group_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{}, []*ec2.SecurityGroup{}, nil) + }, + wantErr: nil, + }, + { + test: "with security groups", + dirName: "aws_vpc_security_group_multiple", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ + { + GroupId: awssdk.String("sg-0254c038e32f25530"), + GroupName: awssdk.String("foo"), + }, + }, []*ec2.SecurityGroup{ + { + GroupId: awssdk.String("sg-9e0204ff"), + GroupName: awssdk.String("default"), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list security groups", + dirName: "aws_vpc_security_group_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllSecurityGroups").Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSecurityGroupResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSecurityGroupResourceType, resourceaws.AwsSecurityGroupResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewVPCSecurityGroupEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSecurityGroupResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSecurityGroupResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSecurityGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestVpcDefaultSecurityGroup(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no security groups", + dirName: "aws_vpc_default_security_group_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{}, []*ec2.SecurityGroup{}, nil) + }, + wantErr: nil, + }, + { + test: "with security groups", + dirName: "aws_vpc_default_security_group_multiple", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ + { + GroupId: awssdk.String("sg-0254c038e32f25530"), + GroupName: awssdk.String("foo"), + }, + }, []*ec2.SecurityGroup{ + { + GroupId: awssdk.String("sg-9e0204ff"), + GroupName: awssdk.String("default"), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list security groups", + dirName: "aws_vpc_default_security_group_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllSecurityGroups").Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDefaultSecurityGroupResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultSecurityGroupResourceType, resourceaws.AwsDefaultSecurityGroupResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewVPCDefaultSecurityGroupEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultSecurityGroupResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDefaultSecurityGroupResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultSecurityGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2NatGateway(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no nat gateways", + dirName: "aws_ec2_nat_gateway_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNatGateways").Return([]*ec2.NatGateway{}, nil) + }, + }, + { + test: "single nat gateway", + dirName: "aws_ec2_nat_gateway_single", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNatGateways").Return([]*ec2.NatGateway{ + {NatGatewayId: awssdk.String("nat-0a5408508b19ef490")}, + }, nil) + }, + }, + { + test: "cannot list nat gateways", + dirName: "aws_ec2_nat_gateway_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllNatGateways").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsNatGatewayResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsNatGatewayResourceType, resourceaws.AwsNatGatewayResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2NatGatewayEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsNatGatewayResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsNatGatewayResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsNatGatewayResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2NetworkACL(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no network ACL", + dirName: "aws_ec2_network_acl_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{}, nil) + }, + }, + { + test: "network acl", + dirName: "aws_ec2_network_acl", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{ + { + NetworkAclId: awssdk.String("acl-043880b4682d2366b"), + IsDefault: awssdk.Bool(false), + }, + { + NetworkAclId: awssdk.String("acl-07a565dbe518c0713"), + IsDefault: awssdk.Bool(false), + }, + { + NetworkAclId: awssdk.String("acl-e88ee595"), + IsDefault: awssdk.Bool(true), + }, + }, nil) + }, + }, + { + test: "cannot list network acl", + dirName: "aws_ec2_network_acl_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllNetworkACLs").Return(nil, awsError) + + alerter.On("SendAlert", + resourceaws.AwsNetworkACLResourceType, + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteAWSTerraform, + remoteerr.NewResourceListingErrorWithType( + awsError, + resourceaws.AwsNetworkACLResourceType, + resourceaws.AwsNetworkACLResourceType, + ), + alerts.EnumerationPhase, + ), + ).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2NetworkACLEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsNetworkACLResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsNetworkACLResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsNetworkACLResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2NetworkACLRule(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no network ACL", + dirName: "aws_ec2_network_acl_rule_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{}, nil) + }, + }, + { + test: "network acl rules", + dirName: "aws_ec2_network_acl_rule", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{ + { + NetworkAclId: awssdk.String("acl-0ad6d657494d17ee2"), // test + IsDefault: awssdk.Bool(false), + Entries: []*ec2.NetworkAclEntry{ + { + Egress: awssdk.Bool(false), + RuleNumber: awssdk.Int64(100), + Protocol: awssdk.String("6"), // tcp + RuleAction: awssdk.String("deny"), + CidrBlock: awssdk.String("0.0.0.0/0"), + }, + { + Egress: awssdk.Bool(false), + RuleNumber: awssdk.Int64(200), + Protocol: awssdk.String("6"), // tcp + RuleAction: awssdk.String("allow"), + Ipv6CidrBlock: awssdk.String("::/0"), + }, + { + Egress: awssdk.Bool(true), + RuleNumber: awssdk.Int64(100), + Protocol: awssdk.String("17"), // udp + RuleAction: awssdk.String("allow"), + CidrBlock: awssdk.String("172.16.1.0/0"), + }, + }, + }, + { + NetworkAclId: awssdk.String("acl-0de54ef59074b622e"), // test2 + IsDefault: awssdk.Bool(false), + Entries: []*ec2.NetworkAclEntry{ + { + Egress: awssdk.Bool(false), + RuleNumber: awssdk.Int64(100), + Protocol: awssdk.String("17"), // udp + RuleAction: awssdk.String("deny"), + CidrBlock: awssdk.String("0.0.0.0/0"), + }, + { + Egress: awssdk.Bool(true), + RuleNumber: awssdk.Int64(100), + Protocol: awssdk.String("17"), // udp + RuleAction: awssdk.String("allow"), + CidrBlock: awssdk.String("172.16.1.0/0"), + }, + }, + }, + }, nil) + }, + }, + { + test: "cannot list network acl", + dirName: "aws_ec2_network_acl_rule_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllNetworkACLs").Return(nil, awsError) + + alerter.On("SendAlert", + resourceaws.AwsNetworkACLRuleResourceType, + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteAWSTerraform, + remoteerr.NewResourceListingErrorWithType( + awsError, + resourceaws.AwsNetworkACLRuleResourceType, + resourceaws.AwsNetworkACLResourceType, + ), + alerts.EnumerationPhase, + ), + ).Return() + }, + wantErr: nil, + }, + } + + version := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", version) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := version + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + } + + remoteLibrary.AddEnumerator(aws2.NewEC2NetworkACLRuleEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsNetworkACLRuleResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsNetworkACLRuleResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsNetworkACLRuleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2DefaultNetworkACL(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no network ACL", + dirName: "aws_ec2_default_network_acl_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{}, nil) + }, + }, + { + test: "default network acl", + dirName: "aws_ec2_default_network_acl", + mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{ + { + NetworkAclId: awssdk.String("acl-043880b4682d2366b"), + IsDefault: awssdk.Bool(false), + }, + { + NetworkAclId: awssdk.String("acl-07a565dbe518c0713"), + IsDefault: awssdk.Bool(false), + }, + { + NetworkAclId: awssdk.String("acl-e88ee595"), + IsDefault: awssdk.Bool(true), + }, + }, nil) + }, + }, + { + test: "cannot list default network acl", + dirName: "aws_ec2_default_network_acl_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllNetworkACLs").Return(nil, awsError) + + alerter.On("SendAlert", + resourceaws.AwsDefaultNetworkACLResourceType, + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteAWSTerraform, + remoteerr.NewResourceListingErrorWithType( + awsError, + resourceaws.AwsDefaultNetworkACLResourceType, + resourceaws.AwsDefaultNetworkACLResourceType, + ), + alerts.EnumerationPhase, + ), + ).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2DefaultNetworkACLEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultNetworkACLResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDefaultNetworkACLResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultNetworkACLResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2Route(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + // route table with no routes case is not possible + // as a default route will always be present in each route table + test: "no routes", + dirName: "aws_ec2_route_empty", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{}, nil) + }, + }, + { + test: "multiple routes (mixed default_route_table and route_table)", + dirName: "aws_ec2_route_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ + { + RouteTableId: awssdk.String("rtb-096bdfb69309c54c3"), // table1 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: awssdk.String("10.0.0.0/16"), + Origin: awssdk.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: awssdk.String("1.1.1.1/32"), + GatewayId: awssdk.String("igw-030e74f73bd67f21b"), + Origin: awssdk.String("CreateRoute"), + }, + { + DestinationIpv6CidrBlock: awssdk.String("::/0"), + GatewayId: awssdk.String("igw-030e74f73bd67f21b"), + Origin: awssdk.String("CreateRoute"), + }, + }, + }, + { + RouteTableId: awssdk.String("rtb-0169b0937fd963ddc"), // table2 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: awssdk.String("10.0.0.0/16"), + Origin: awssdk.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: awssdk.String("0.0.0.0/0"), + GatewayId: awssdk.String("igw-030e74f73bd67f21b"), + Origin: awssdk.String("CreateRoute"), + }, + { + DestinationIpv6CidrBlock: awssdk.String("::/0"), + GatewayId: awssdk.String("igw-030e74f73bd67f21b"), + Origin: awssdk.String("CreateRoute"), + }, + }, + }, + { + RouteTableId: awssdk.String("rtb-02780c485f0be93c5"), // default_table + VpcId: awssdk.String("vpc-09fe5abc2309ba49d"), + Associations: []*ec2.RouteTableAssociation{ + { + Main: awssdk.Bool(true), + }, + }, + Routes: []*ec2.Route{ + { + DestinationCidrBlock: awssdk.String("10.0.0.0/16"), + Origin: awssdk.String("CreateRouteTable"), // default route + }, + { + DestinationCidrBlock: awssdk.String("10.1.1.0/24"), + GatewayId: awssdk.String("igw-030e74f73bd67f21b"), + Origin: awssdk.String("CreateRoute"), + }, + { + DestinationCidrBlock: awssdk.String("10.1.2.0/24"), + GatewayId: awssdk.String("igw-030e74f73bd67f21b"), + Origin: awssdk.String("CreateRoute"), + }, + }, + }, + { + RouteTableId: awssdk.String(""), // table3 + Routes: []*ec2.Route{ + { + DestinationCidrBlock: awssdk.String("10.0.0.0/16"), + Origin: awssdk.String("CreateRouteTable"), // default route + }, + }, + }, + }, nil) + }, + }, + { + test: "cannot list routes", + dirName: "aws_ec2_route_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllRouteTables").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsRouteResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRouteResourceType, resourceaws.AwsRouteTableResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2RouteEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRouteResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRouteResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRouteResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestVpcSecurityGroupRule(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no security group rules", + dirName: "aws_vpc_security_group_rule_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ + { + GroupId: awssdk.String("sg-0254c038e32f25530"), + IpPermissions: []*ec2.IpPermission{}, + IpPermissionsEgress: []*ec2.IpPermission{}, + }, + }, nil, nil) + }, + wantErr: nil, + }, + { + test: "with security group rules", + dirName: "aws_vpc_security_group_rule_multiple", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ + { + GroupId: awssdk.String("sg-0254c038e32f25530"), + IpPermissions: []*ec2.IpPermission{ + { + FromPort: awssdk.Int64(0), + ToPort: awssdk.Int64(65535), + IpProtocol: awssdk.String("tcp"), + UserIdGroupPairs: []*ec2.UserIdGroupPair{ + { + GroupId: awssdk.String("sg-0254c038e32f25530"), + }, + { + GroupId: awssdk.String("sg-9e0204ff"), + }, + }, + }, + { + IpProtocol: awssdk.String("-1"), + IpRanges: []*ec2.IpRange{ + { + CidrIp: awssdk.String("1.2.0.0/16"), + }, + { + CidrIp: awssdk.String("5.6.7.0/24"), + }, + }, + Ipv6Ranges: []*ec2.Ipv6Range{ + { + CidrIpv6: awssdk.String("::/0"), + }, + }, + }, + }, + IpPermissionsEgress: []*ec2.IpPermission{ + { + IpProtocol: awssdk.String("-1"), + IpRanges: []*ec2.IpRange{ + { + CidrIp: awssdk.String("0.0.0.0/0"), + }, + }, + Ipv6Ranges: []*ec2.Ipv6Range{ + { + CidrIpv6: awssdk.String("::/0"), + }, + }, + }, + }, + }, + { + GroupId: awssdk.String("sg-0cc8b3c3c2851705a"), + IpPermissions: []*ec2.IpPermission{ + { + FromPort: awssdk.Int64(443), + ToPort: awssdk.Int64(443), + IpProtocol: awssdk.String("tcp"), + IpRanges: []*ec2.IpRange{ + { + CidrIp: awssdk.String("0.0.0.0/0"), + }, + }, + }, + }, + IpPermissionsEgress: []*ec2.IpPermission{ + { + IpProtocol: awssdk.String("-1"), + IpRanges: []*ec2.IpRange{ + { + CidrIp: awssdk.String("0.0.0.0/0"), + }, + }, + Ipv6Ranges: []*ec2.Ipv6Range{ + { + CidrIpv6: awssdk.String("::/0"), + }, + }, + }, + { + IpProtocol: awssdk.String("5"), + IpRanges: []*ec2.IpRange{ + { + CidrIp: awssdk.String("0.0.0.0/0"), + }, + }, + }, + }, + }, + }, nil, nil) + }, + wantErr: nil, + }, + { + test: "cannot list security group rules", + dirName: "aws_vpc_security_group_rule_empty", + mocks: func(client *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllSecurityGroups").Once().Return(nil, nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSecurityGroupRuleResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSecurityGroupRuleResourceType, resourceaws.AwsSecurityGroupResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewVPCSecurityGroupRuleEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSecurityGroupRuleResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSecurityGroupRuleResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSecurityGroupRuleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestEC2LaunchTemplate(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no launch template", + dirName: "aws_launch_template", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("DescribeLaunchTemplates").Return([]*ec2.LaunchTemplate{}, nil) + }, + }, + { + test: "multiple launch templates", + dirName: "aws_launch_template_multiple", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + launchTemplates := []*ec2.LaunchTemplate{ + {LaunchTemplateId: awssdk.String("lt-0ed993d09ce6afc67"), LatestVersionNumber: awssdk.Int64(1)}, + {LaunchTemplateId: awssdk.String("lt-00b2d18c6cee7fe23"), LatestVersionNumber: awssdk.Int64(1)}, + } + + repository.On("DescribeLaunchTemplates").Return(launchTemplates, nil) + }, + }, + { + test: "cannot list launch templates", + dirName: "aws_launch_template", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("DescribeLaunchTemplates").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsLaunchTemplateResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLaunchTemplateResourceType, resourceaws.AwsLaunchTemplateResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewLaunchTemplateEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsLaunchTemplateResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsLaunchTemplateResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsLaunchTemplateResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestEC2EbsEncryptionByDefault(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockEC2Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no encryption by default resource", + dirName: "aws_ebs_encryption_by_default_list", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + repository.On("IsEbsEncryptionEnabledByDefault").Return(false, nil) + }, + }, + { + test: "cannot list encryption by default resources", + dirName: "aws_ebs_encryption_by_default_error", + mocks: func(repository *repository2.MockEC2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("IsEbsEncryptionEnabledByDefault").Return(false, awsError) + + alerter.On("SendAlert", resourceaws.AwsEbsEncryptionByDefaultResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEbsEncryptionByDefaultResourceType, resourceaws.AwsEbsEncryptionByDefaultResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockEC2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.EC2Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewEC2Repository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewEC2EbsEncryptionByDefaultEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsEbsEncryptionByDefaultResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsEbsEncryptionByDefaultResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsEbsEncryptionByDefaultResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_ecr_scanner_test.go b/enumeration/remote/aws_ecr_scanner_test.go new file mode 100644 index 00000000..3f4b7e46 --- /dev/null +++ b/enumeration/remote/aws_ecr_scanner_test.go @@ -0,0 +1,199 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ecr" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestECRRepository(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockECRRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no repository", + dirName: "aws_ecr_repository_empty", + mocks: func(client *repository2.MockECRRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllRepositories").Return([]*ecr.Repository{}, nil) + }, + err: nil, + }, + { + test: "multiple repositories", + dirName: "aws_ecr_repository_multiple", + mocks: func(client *repository2.MockECRRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllRepositories").Return([]*ecr.Repository{ + {RepositoryName: awssdk.String("test_ecr")}, + {RepositoryName: awssdk.String("bar")}, + }, nil) + }, + err: nil, + }, + { + test: "cannot list repository", + dirName: "aws_ecr_repository_empty", + mocks: func(client *repository2.MockECRRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllRepositories").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsEcrRepositoryResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEcrRepositoryResourceType, resourceaws.AwsEcrRepositoryResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockECRRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ECRRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewECRRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewECRRepositoryEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsEcrRepositoryResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsEcrRepositoryResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsEcrRepositoryResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestECRRepositoryPolicy(t *testing.T) { + tests := []struct { + test string + mocks func(*repository2.MockECRRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + err error + }{ + { + test: "single repository policy", + mocks: func(client *repository2.MockECRRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllRepositories").Return([]*ecr.Repository{ + {RepositoryName: awssdk.String("test_ecr_repo_policy")}, + {RepositoryName: awssdk.String("test_ecr_repo_without_policy")}, + }, nil) + client.On("GetRepositoryPolicy", &ecr.Repository{ + RepositoryName: awssdk.String("test_ecr_repo_policy"), + }).Return(&ecr.GetRepositoryPolicyOutput{ + RegistryId: awssdk.String("1"), + RepositoryName: awssdk.String("test_ecr_repo_policy"), + }, nil) + client.On("GetRepositoryPolicy", &ecr.Repository{ + RepositoryName: awssdk.String("test_ecr_repo_without_policy"), + }).Return(nil, &ecr.RepositoryPolicyNotFoundException{}) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + + assert.Equal(t, got[0].ResourceId(), "test_ecr_repo_policy") + }, + err: nil, + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockECRRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ECRRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewECRRepositoryPolicyEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_elasticache_scanner_test.go b/enumeration/remote/aws_elasticache_scanner_test.go new file mode 100644 index 00000000..299a6d2f --- /dev/null +++ b/enumeration/remote/aws_elasticache_scanner_test.go @@ -0,0 +1,111 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/elasticache" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestElastiCacheCluster(t *testing.T) { + dummyError := errors.New("dummy error") + + tests := []struct { + test string + mocks func(*repository2.MockElastiCacheRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no elasticache clusters", + mocks: func(repository *repository2.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllCacheClusters").Return([]*elasticache.CacheCluster{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "should list elasticache clusters", + mocks: func(repository *repository2.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllCacheClusters").Return([]*elasticache.CacheCluster{ + {CacheClusterId: awssdk.String("cluster-foo")}, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, got[0].ResourceId(), "cluster-foo") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsElastiCacheClusterResourceType) + }, + }, + { + test: "cannot list elasticache clusters (403)", + mocks: func(repository *repository2.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllCacheClusters").Return(nil, awsError) + alerter.On("SendAlert", resourceaws.AwsElastiCacheClusterResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsElastiCacheClusterResourceType, resourceaws.AwsElastiCacheClusterResourceType), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "cannot list elasticache clusters (dummy error)", + mocks: func(repository *repository2.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllCacheClusters").Return(nil, dummyError) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsElastiCacheClusterResourceType, ""), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockElastiCacheRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ElastiCacheRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws.NewElastiCacheClusterEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_elb_scanner_test.go b/enumeration/remote/aws_elb_scanner_test.go new file mode 100644 index 00000000..f4ae592e --- /dev/null +++ b/enumeration/remote/aws_elb_scanner_test.go @@ -0,0 +1,118 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/elb" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestELB_LoadBalancer(t *testing.T) { + dummyError := errors.New("dummy error") + + tests := []struct { + test string + mocks func(*repository2.MockELBRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no load balancer", + mocks: func(repository *repository2.MockELBRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elb.LoadBalancerDescription{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "should list load balancers", + mocks: func(repository *repository2.MockELBRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elb.LoadBalancerDescription{ + { + LoadBalancerName: awssdk.String("acc-test-lb-tf"), + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "acc-test-lb-tf", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsClassicLoadBalancerResourceType, got[0].ResourceType()) + }, + }, + { + test: "cannot list load balancers", + mocks: func(repository *repository2.MockELBRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllLoadBalancers").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsClassicLoadBalancerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsClassicLoadBalancerResourceType, resourceaws.AwsClassicLoadBalancerResourceType), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "cannot list load balancers (dummy error)", + mocks: func(repository *repository2.MockELBRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return(nil, dummyError) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsClassicLoadBalancerResourceType, ""), + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockELBRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ELBRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws.NewClassicLoadBalancerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_elbv2_scanner_test.go b/enumeration/remote/aws_elbv2_scanner_test.go new file mode 100644 index 00000000..a59d53f7 --- /dev/null +++ b/enumeration/remote/aws_elbv2_scanner_test.go @@ -0,0 +1,247 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestELBV2_LoadBalancer(t *testing.T) { + dummyError := errors.New("dummy error") + + tests := []struct { + test string + mocks func(*repository2.MockELBV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no load balancer", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "should list load balancers", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ + { + LoadBalancerArn: awssdk.String("arn:aws:elasticloadbalancing:us-east-1:533948124879:loadbalancer/app/acc-test-lb-tf/9114c60e08560420"), + LoadBalancerName: awssdk.String("acc-test-lb-tf"), + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "arn:aws:elasticloadbalancing:us-east-1:533948124879:loadbalancer/app/acc-test-lb-tf/9114c60e08560420", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsLoadBalancerResourceType, got[0].ResourceType()) + }, + }, + { + test: "cannot list load balancers (403)", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllLoadBalancers").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsLoadBalancerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLoadBalancerResourceType, resourceaws.AwsLoadBalancerResourceType), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "cannot list load balancers (dummy error)", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return(nil, dummyError) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsLoadBalancerResourceType, ""), + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockELBV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ELBV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewLoadBalancerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestELBV2_LoadBalancerListener(t *testing.T) { + dummyError := errors.New("dummy error") + + tests := []struct { + test string + mocks func(*repository2.MockELBV2Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no load balancer listener", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ + { + LoadBalancerArn: awssdk.String("test-lb"), + }, + }, nil) + repository.On("ListAllLoadBalancerListeners", "test-lb").Return([]*elbv2.Listener{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "should list load balancer listener", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ + { + LoadBalancerArn: awssdk.String("test-lb"), + }, + }, nil) + + repository.On("ListAllLoadBalancerListeners", "test-lb").Return([]*elbv2.Listener{ + { + ListenerArn: awssdk.String("test-lb-listener-1"), + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "test-lb-listener-1", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsLoadBalancerListenerResourceType, got[0].ResourceType()) + }, + }, + { + test: "cannot list load balancer listeners (403)", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ + { + LoadBalancerArn: awssdk.String("test-lb"), + }, + }, nil) + + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllLoadBalancerListeners", "test-lb").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsLoadBalancerListenerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingError(awsError, resourceaws.AwsLoadBalancerListenerResourceType), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "cannot list load balancers (403)", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllLoadBalancers").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsLoadBalancerListenerResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLoadBalancerListenerResourceType, resourceaws.AwsLoadBalancerResourceType), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "cannot list load balancer listeners (dummy error)", + mocks: func(repository *repository2.MockELBV2Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ + { + LoadBalancerArn: awssdk.String("test-lb"), + }, + }, nil) + + repository.On("ListAllLoadBalancerListeners", "test-lb").Return(nil, dummyError) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsLoadBalancerListenerResourceType, ""), + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockELBV2Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ELBV2Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewLoadBalancerListenerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_iam_scanner_test.go b/enumeration/remote/aws_iam_scanner_test.go new file mode 100644 index 00000000..78adf775 --- /dev/null +++ b/enumeration/remote/aws_iam_scanner_test.go @@ -0,0 +1,1324 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/iam" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestIamUser(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam user", + dirName: "aws_iam_user_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllUsers").Return([]*iam.User{}, nil) + }, + wantErr: nil, + }, + { + test: "iam multiples users", + dirName: "aws_iam_user_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllUsers").Return([]*iam.User{ + { + UserName: aws.String("test-driftctl-0"), + }, + { + UserName: aws.String("test-driftctl-1"), + }, + { + UserName: aws.String("test-driftctl-2"), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list iam user", + dirName: "aws_iam_user_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllUsers").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamUserResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamUserEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamUserResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamUserResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamUserResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamUserPolicy(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam user policy", + dirName: "aws_iam_user_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + users := []*iam.User{ + { + UserName: aws.String("loadbalancer"), + }, + } + repo.On("ListAllUsers").Return(users, nil) + repo.On("ListAllUserPolicies", users).Return([]string{}, nil) + }, + wantErr: nil, + }, + { + test: "iam multiples users multiple policies", + dirName: "aws_iam_user_policy_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + users := []*iam.User{ + { + UserName: aws.String("loadbalancer"), + }, + { + UserName: aws.String("loadbalancer2"), + }, + { + UserName: aws.String("loadbalancer3"), + }, + } + repo.On("ListAllUsers").Return(users, nil) + repo.On("ListAllUserPolicies", users).Once().Return([]string{ + *aws.String("loadbalancer:test"), + *aws.String("loadbalancer:test2"), + *aws.String("loadbalancer:test3"), + *aws.String("loadbalancer:test4"), + *aws.String("loadbalancer2:test2"), + *aws.String("loadbalancer2:test22"), + *aws.String("loadbalancer2:test23"), + *aws.String("loadbalancer2:test24"), + *aws.String("loadbalancer3:test3"), + *aws.String("loadbalancer3:test32"), + *aws.String("loadbalancer3:test33"), + *aws.String("loadbalancer3:test34"), + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list user", + dirName: "aws_iam_user_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllUsers").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamUserPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "cannot list user policy", + dirName: "aws_iam_user_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllUsers").Once().Return([]*iam.User{}, nil) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllUserPolicies", mock.Anything).Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamUserPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyResourceType, resourceaws.AwsIamUserPolicyResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamUserPolicyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamUserPolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamUserPolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamUserPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamPolicy(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam custom policies", + dirName: "aws_iam_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllPolicies").Once().Return([]*iam.Policy{}, nil) + }, + wantErr: nil, + }, + { + test: "iam multiples custom policies", + dirName: "aws_iam_policy_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllPolicies").Once().Return([]*iam.Policy{ + { + Arn: aws.String("arn:aws:iam::929327065333:policy/policy-0"), + }, + { + Arn: aws.String("arn:aws:iam::929327065333:policy/policy-1"), + }, + { + Arn: aws.String("arn:aws:iam::929327065333:policy/policy-2"), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list iam custom policies", + dirName: "aws_iam_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllPolicies").Once().Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamPolicyResourceType, resourceaws.AwsIamPolicyResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamPolicyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamPolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamPolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamRole(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam roles", + dirName: "aws_iam_role_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRoles").Return([]*iam.Role{}, nil) + }, + wantErr: nil, + }, + { + test: "iam multiples roles", + dirName: "aws_iam_role_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRoles").Return([]*iam.Role{ + { + RoleName: aws.String("test_role_0"), + Path: aws.String("/"), + }, + { + RoleName: aws.String("test_role_1"), + Path: aws.String("/"), + }, + { + RoleName: aws.String("test_role_2"), + Path: aws.String("/"), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "iam roles ignore services roles", + dirName: "aws_iam_role_ignore_services_roles", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRoles").Return([]*iam.Role{ + { + RoleName: aws.String("AWSServiceRoleForOrganizations"), + Path: aws.String("/aws-service-role/organizations.amazonaws.com/"), + }, + { + RoleName: aws.String("AWSServiceRoleForSupport"), + Path: aws.String("/aws-service-role/support.amazonaws.com/"), + }, + { + RoleName: aws.String("AWSServiceRoleForTrustedAdvisor"), + Path: aws.String("/aws-service-role/trustedadvisor.amazonaws.com/"), + }, + }, nil) + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamRoleEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamRoleResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamRoleResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamRoleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamRolePolicyAttachment(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no iam role policy", + dirName: "aws_aws_iam_role_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + roles := []*iam.Role{ + { + RoleName: aws.String("test-role"), + }, + } + repo.On("ListAllRoles").Return(roles, nil) + repo.On("ListAllRolePolicyAttachments", roles).Return([]*repository2.AttachedRolePolicy{}, nil) + }, + err: nil, + }, + { + test: "iam multiples roles multiple policies", + dirName: "aws_iam_role_policy_attachment_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + roles := []*iam.Role{ + { + RoleName: aws.String("test-role"), + }, + { + RoleName: aws.String("test-role2"), + }, + } + repo.On("ListAllRoles").Return(roles, nil) + repo.On("ListAllRolePolicyAttachments", roles).Return([]*repository2.AttachedRolePolicy{ + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy"), + PolicyName: aws.String("test-policy"), + }, + RoleName: *aws.String("test-role"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy2"), + PolicyName: aws.String("test-policy2"), + }, + RoleName: *aws.String("test-role"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy3"), + PolicyName: aws.String("test-policy3"), + }, + RoleName: *aws.String("test-role"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy"), + PolicyName: aws.String("test-policy"), + }, + RoleName: *aws.String("test-role2"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy2"), + PolicyName: aws.String("test-policy2"), + }, + RoleName: *aws.String("test-role2"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy3"), + PolicyName: aws.String("test-policy3"), + }, + RoleName: *aws.String("test-role2"), + }, + }, nil) + }, + err: nil, + }, + { + test: "iam multiples roles for ignored roles", + dirName: "aws_iam_role_policy_attachment_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + roles := []*iam.Role{ + { + RoleName: aws.String("AWSServiceRoleForSupport"), + }, + { + RoleName: aws.String("AWSServiceRoleForOrganizations"), + }, + { + RoleName: aws.String("AWSServiceRoleForTrustedAdvisor"), + }, + } + repo.On("ListAllRoles").Return(roles, nil) + }, + }, + { + test: "Cannot list roles", + dirName: "aws_iam_role_policy_attachment_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllRoles").Once().Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamRolePolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyAttachmentResourceType, resourceaws.AwsIamRoleResourceType), alerts.EnumerationPhase)).Return() + }, + }, + { + test: "Cannot list roles policy attachment", + dirName: "aws_iam_role_policy_attachment_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRoles").Once().Return([]*iam.Role{{RoleName: aws.String("test")}}, nil) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllRolePolicyAttachments", mock.Anything).Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamRolePolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyAttachmentResourceType, resourceaws.AwsIamRolePolicyAttachmentResourceType), alerts.EnumerationPhase)).Return() + }, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamRolePolicyAttachmentEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamRolePolicyAttachmentResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamRolePolicyAttachmentResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamRolePolicyAttachmentResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamAccessKey(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam access_key", + dirName: "aws_iam_access_key_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + users := []*iam.User{ + { + UserName: aws.String("test-driftctl"), + }, + } + repo.On("ListAllUsers").Return(users, nil) + repo.On("ListAllAccessKeys", users).Return([]*iam.AccessKeyMetadata{}, nil) + }, + wantErr: nil, + }, + { + test: "iam multiples keys for multiples users", + dirName: "aws_iam_access_key_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + users := []*iam.User{ + { + UserName: aws.String("test-driftctl"), + }, + } + repo.On("ListAllUsers").Return(users, nil) + repo.On("ListAllAccessKeys", users).Return([]*iam.AccessKeyMetadata{ + { + AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), + UserName: aws.String("test-driftctl"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), + UserName: aws.String("test-driftctl"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD26EJME25D"), + UserName: aws.String("test-driftctl2"), + }, + { + AccessKeyId: aws.String("AKIA5QYBVVD2SWDFVVMG"), + UserName: aws.String("test-driftctl2"), + }, + }, nil) + }, + wantErr: nil, + }, + { + test: "Cannot list iam user", + dirName: "aws_iam_access_key_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllUsers").Once().Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamAccessKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamAccessKeyResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "Cannot list iam access_key", + dirName: "aws_iam_access_key_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllUsers").Once().Return([]*iam.User{}, nil) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllAccessKeys", mock.Anything).Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamAccessKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamAccessKeyResourceType, resourceaws.AwsIamAccessKeyResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamAccessKeyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamAccessKeyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamAccessKeyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamAccessKeyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamUserPolicyAttachment(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam user policy", + dirName: "aws_iam_user_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + users := []*iam.User{ + { + UserName: aws.String("loadbalancer"), + }, + } + repo.On("ListAllUsers").Return(users, nil) + repo.On("ListAllUserPolicyAttachments", users).Return([]*repository2.AttachedUserPolicy{}, nil) + }, + wantErr: nil, + }, + { + test: "iam multiples users multiple policies", + dirName: "aws_iam_user_policy_attachment_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + users := []*iam.User{ + { + UserName: aws.String("loadbalancer"), + }, + { + UserName: aws.String("loadbalancer2"), + }, + { + UserName: aws.String("loadbalancer3"), + }, + } + repo.On("ListAllUsers").Return(users, nil) + repo.On("ListAllUserPolicyAttachments", users).Return([]*repository2.AttachedUserPolicy{ + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test"), + PolicyName: aws.String("test"), + }, + UserName: *aws.String("loadbalancer"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test2"), + PolicyName: aws.String("test2"), + }, + UserName: *aws.String("loadbalancer"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test3"), + PolicyName: aws.String("test3"), + }, + UserName: *aws.String("loadbalancer"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test4"), + PolicyName: aws.String("test4"), + }, + UserName: *aws.String("loadbalancer"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test"), + PolicyName: aws.String("test"), + }, + UserName: *aws.String("loadbalancer2"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test2"), + PolicyName: aws.String("test2"), + }, + UserName: *aws.String("loadbalancer2"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test3"), + PolicyName: aws.String("test3"), + }, + UserName: *aws.String("loadbalancer2"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test4"), + PolicyName: aws.String("test4"), + }, + UserName: *aws.String("loadbalancer2"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test"), + PolicyName: aws.String("test"), + }, + UserName: *aws.String("loadbalancer3"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test2"), + PolicyName: aws.String("test2"), + }, + UserName: *aws.String("loadbalancer3"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test3"), + PolicyName: aws.String("test3"), + }, + UserName: *aws.String("loadbalancer3"), + }, + { + AttachedPolicy: iam.AttachedPolicy{ + PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test4"), + PolicyName: aws.String("test4"), + }, + UserName: *aws.String("loadbalancer3"), + }, + }, nil) + + }, + wantErr: nil, + }, + { + test: "cannot list user", + dirName: "aws_iam_user_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllUsers").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamUserPolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyAttachmentResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "cannot list user policies attachment", + dirName: "aws_iam_user_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllUsers").Once().Return([]*iam.User{}, nil) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllUserPolicyAttachments", mock.Anything).Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamUserPolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyAttachmentResourceType, resourceaws.AwsIamUserPolicyAttachmentResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamUserPolicyAttachmentEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamUserPolicyAttachmentResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamUserPolicyAttachmentResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamUserPolicyAttachmentResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamRolePolicy(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockIAMRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no iam role policy", + dirName: "aws_iam_role_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + roles := []*iam.Role{ + { + RoleName: aws.String("test_role"), + }, + } + repo.On("ListAllRoles").Return(roles, nil) + repo.On("ListAllRolePolicies", roles).Return([]repository2.RolePolicy{}, nil) + }, + wantErr: nil, + }, + { + test: "multiples roles with inline policies", + dirName: "aws_iam_role_policy_multiple", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + roles := []*iam.Role{ + { + RoleName: aws.String("test_role_0"), + }, + { + RoleName: aws.String("test_role_1"), + }, + } + repo.On("ListAllRoles").Return(roles, nil) + repo.On("ListAllRolePolicies", roles).Return([]repository2.RolePolicy{ + {Policy: "policy-role0-0", RoleName: "test_role_0"}, + {Policy: "policy-role0-1", RoleName: "test_role_0"}, + {Policy: "policy-role0-2", RoleName: "test_role_0"}, + {Policy: "policy-role1-0", RoleName: "test_role_1"}, + {Policy: "policy-role1-1", RoleName: "test_role_1"}, + {Policy: "policy-role1-2", RoleName: "test_role_1"}, + }, nil).Once() + }, + wantErr: nil, + }, + { + test: "Cannot list roles", + dirName: "aws_iam_role_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllRoles").Once().Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamRolePolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyResourceType, resourceaws.AwsIamRoleResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "cannot list role policy", + dirName: "aws_iam_role_policy_empty", + mocks: func(repo *repository2.MockIAMRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllRoles").Once().Return([]*iam.Role{}, nil) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllRolePolicies", mock.Anything).Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsIamRolePolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyResourceType, resourceaws.AwsIamRolePolicyResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.IAMRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewIAMRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewIamRolePolicyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamRolePolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsIamRolePolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsIamRolePolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestIamGroupPolicy(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockIAMRepository) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "multiple groups, with multiples policies", + mocks: func(repository *repository2.MockIAMRepository) { + repository.On("ListAllGroups").Return(nil, nil) + repository.On("ListAllGroupPolicies", []*iam.Group(nil)). + Return([]string{"group1:policy1", "group2:policy2"}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, resourceaws.AwsIamGroupPolicyResourceType, got[0].ResourceType()) + assert.Equal(t, "group1:policy1", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsIamGroupPolicyResourceType, got[1].ResourceType()) + assert.Equal(t, "group2:policy2", got[1].ResourceId()) + }, + }, + { + test: "cannot list groups", + mocks: func(repository *repository2.MockIAMRepository) { + repository.On("ListAllGroups").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsIamGroupPolicyResourceType, resourceaws.AwsIamGroupResourceType), + }, + { + test: "cannot list policies", + mocks: func(repository *repository2.MockIAMRepository) { + repository.On("ListAllGroups").Return(nil, nil) + repository.On("ListAllGroupPolicies", []*iam.Group(nil)).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsIamGroupPolicyResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo) + + var repo repository2.IAMRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewIamGroupPolicyEnumerator( + repo, factory, + )) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestIamGroup(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockIAMRepository) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "multiple groups, with multiples groups", + mocks: func(repository *repository2.MockIAMRepository) { + repository.On("ListAllGroups").Return([]*iam.Group{ + { + GroupName: aws.String("group1"), + }, + { + GroupName: aws.String("group2"), + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, resourceaws.AwsIamGroupResourceType, got[0].ResourceType()) + assert.Equal(t, "group1", got[0].ResourceId()) + assert.Equal(t, resourceaws.AwsIamGroupResourceType, got[1].ResourceType()) + assert.Equal(t, "group2", got[1].ResourceId()) + }, + }, + { + test: "cannot list groups", + mocks: func(repository *repository2.MockIAMRepository) { + repository.On("ListAllGroups").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsIamGroupResourceType), + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockIAMRepository{} + c.mocks(fakeRepo) + + var repo repository2.IAMRepository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewIamGroupEnumerator( + repo, factory, + )) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_kms_scanner_test.go b/enumeration/remote/aws_kms_scanner_test.go new file mode 100644 index 00000000..166f62fa --- /dev/null +++ b/enumeration/remote/aws_kms_scanner_test.go @@ -0,0 +1,226 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestKMSKey(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockKMSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no keys", + dirName: "aws_kms_key_empty", + mocks: func(repository *repository2.MockKMSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllKeys").Return([]*kms.KeyListEntry{}, nil) + }, + }, + { + test: "multiple keys", + dirName: "aws_kms_key_multiple", + mocks: func(repository *repository2.MockKMSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllKeys").Return([]*kms.KeyListEntry{ + {KeyId: awssdk.String("8ee21d91-c000-428c-8032-235aac55da36")}, + {KeyId: awssdk.String("5d765f32-bfdc-4610-b6ab-f82db5d0601b")}, + {KeyId: awssdk.String("89d2c023-ea53-40a5-b20a-d84905c622d7")}, + }, nil) + }, + }, + { + test: "cannot list keys", + dirName: "aws_kms_key_list", + mocks: func(repository *repository2.MockKMSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllKeys").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsKmsKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsKmsKeyResourceType, resourceaws.AwsKmsKeyResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockKMSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.KMSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewKMSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewKMSKeyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsKmsKeyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsKmsKeyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsKmsKeyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestKMSAlias(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockKMSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no aliases", + dirName: "aws_kms_alias_empty", + mocks: func(repository *repository2.MockKMSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllAliases").Return([]*kms.AliasListEntry{}, nil) + }, + }, + { + test: "multiple aliases", + dirName: "aws_kms_alias_multiple", + mocks: func(repository *repository2.MockKMSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllAliases").Return([]*kms.AliasListEntry{ + {AliasName: awssdk.String("alias/foo")}, + {AliasName: awssdk.String("alias/bar")}, + {AliasName: awssdk.String("alias/baz20210225124429210500000001")}, + }, nil) + }, + }, + { + test: "cannot list aliases", + dirName: "aws_kms_alias_list", + mocks: func(repository *repository2.MockKMSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllAliases").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsKmsAliasResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsKmsAliasResourceType, resourceaws.AwsKmsAliasResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockKMSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.KMSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewKMSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewKMSAliasEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsKmsAliasResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsKmsAliasResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsKmsAliasResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_lambda_scanner_test.go b/enumeration/remote/aws_lambda_scanner_test.go new file mode 100644 index 00000000..5025d796 --- /dev/null +++ b/enumeration/remote/aws_lambda_scanner_test.go @@ -0,0 +1,264 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/lambda" + "github.com/pkg/errors" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanLambdaFunction(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockLambdaRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no lambda functions", + dirName: "aws_lambda_function_empty", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{}, nil) + }, + err: nil, + }, + { + test: "with lambda functions", + dirName: "aws_lambda_function_multiple", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{ + { + FunctionName: awssdk.String("foo"), + }, + { + FunctionName: awssdk.String("bar"), + }, + }, nil) + }, + err: nil, + }, + { + test: "One lambda with signing", + dirName: "aws_lambda_function_signed", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{ + { + FunctionName: awssdk.String("foo"), + }, + }, nil) + }, + err: nil, + }, + { + test: "cannot list lambda functions", + dirName: "aws_lambda_function_empty", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{}, awsError) + + alerter.On("SendAlert", resourceaws.AwsLambdaFunctionResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLambdaFunctionResourceType, resourceaws.AwsLambdaFunctionResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockLambdaRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.LambdaRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewLambdaRepository(session, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewLambdaFunctionEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsLambdaFunctionResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsLambdaFunctionResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsLambdaFunctionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestScanLambdaEventSourceMapping(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockLambdaRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no EventSourceMapping", + dirName: "aws_lambda_source_mapping_empty", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{}, nil) + }, + err: nil, + }, + { + test: "with 2 sqs EventSourceMapping", + dirName: "aws_lambda_source_mapping_sqs_multiple", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{ + { + UUID: awssdk.String("13ff66f8-37eb-4ad6-a0a8-594fea72df4f"), + }, + { + UUID: awssdk.String("4ad7e2b3-79e9-4713-9d9d-5af2c01d9058"), + }, + }, nil) + }, + err: nil, + }, + { + test: "with dynamo EventSourceMapping", + dirName: "aws_lambda_source_mapping_dynamo_multiple", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{ + { + UUID: awssdk.String("1aa9c4a0-060b-41c1-a9ae-dc304ebcdb00"), + }, + }, nil) + }, + err: nil, + }, + { + test: "cannot list lambda functions", + dirName: "aws_lambda_function_empty", + mocks: func(repo *repository2.MockLambdaRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{}, awsError) + + alerter.On("SendAlert", resourceaws.AwsLambdaEventSourceMappingResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLambdaEventSourceMappingResourceType, resourceaws.AwsLambdaEventSourceMappingResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockLambdaRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.LambdaRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewLambdaRepository(session, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewLambdaEventSourceMappingEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsLambdaEventSourceMappingResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsLambdaEventSourceMappingResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsLambdaEventSourceMappingResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_rds_scanner_test.go b/enumeration/remote/aws_rds_scanner_test.go new file mode 100644 index 00000000..594c5a9c --- /dev/null +++ b/enumeration/remote/aws_rds_scanner_test.go @@ -0,0 +1,335 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/rds" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestRDSDBInstance(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockRDSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no db instances", + dirName: "aws_rds_db_instance_empty", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBInstances").Return([]*rds.DBInstance{}, nil) + }, + }, + { + test: "single db instance", + dirName: "aws_rds_db_instance_single", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBInstances").Return([]*rds.DBInstance{ + {DBInstanceIdentifier: awssdk.String("terraform-20201015115018309600000001")}, + }, nil) + }, + }, + { + test: "multiple mixed db instances", + dirName: "aws_rds_db_instance_multiple", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBInstances").Return([]*rds.DBInstance{ + {DBInstanceIdentifier: awssdk.String("terraform-20201015115018309600000001")}, + {DBInstanceIdentifier: awssdk.String("database-1")}, + }, nil) + }, + }, + { + test: "cannot list db instances", + dirName: "aws_rds_db_instance_list", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllDBInstances").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDbInstanceResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDbInstanceResourceType, resourceaws.AwsDbInstanceResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockRDSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.RDSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewRDSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewRDSDBInstanceEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDbInstanceResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDbInstanceResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDbInstanceResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestRDSDBSubnetGroup(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockRDSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no db subnet groups", + dirName: "aws_rds_db_subnet_group_empty", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBSubnetGroups").Return([]*rds.DBSubnetGroup{}, nil) + }, + }, + { + test: "multiple db subnet groups", + dirName: "aws_rds_db_subnet_group_multiple", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBSubnetGroups").Return([]*rds.DBSubnetGroup{ + {DBSubnetGroupName: awssdk.String("foo")}, + {DBSubnetGroupName: awssdk.String("bar")}, + }, nil) + }, + }, + { + test: "cannot list db subnet groups", + dirName: "aws_rds_db_subnet_group_list", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllDBSubnetGroups").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsDbSubnetGroupResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDbSubnetGroupResourceType, resourceaws.AwsDbSubnetGroupResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockRDSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.RDSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewRDSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewRDSDBSubnetGroupEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsDbSubnetGroupResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsDbSubnetGroupResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsDbSubnetGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestRDSCluster(t *testing.T) { + tests := []struct { + test string + dirName string + mocks func(*repository2.MockRDSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no cluster", + dirName: "aws_rds_cluster_empty", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBClusters").Return([]*rds.DBCluster{}, nil) + }, + }, + { + test: "should return one result", + dirName: "aws_rds_clusters_results", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllDBClusters").Return([]*rds.DBCluster{ + { + DBClusterIdentifier: awssdk.String("aurora-cluster-demo"), + DatabaseName: awssdk.String("mydb"), + }, + { + DBClusterIdentifier: awssdk.String("aurora-cluster-demo-2"), + }, + }, nil) + }, + }, + { + test: "cannot list clusters", + dirName: "aws_rds_cluster_denied", + mocks: func(repository *repository2.MockRDSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") + repository.On("ListAllDBClusters").Return(nil, awsError).Once() + + alerter.On("SendAlert", resourceaws.AwsRDSClusterResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRDSClusterResourceType, resourceaws.AwsRDSClusterResourceType), alerts.EnumerationPhase)).Return().Once() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockRDSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.RDSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewRDSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewRDSClusterEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRDSClusterResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRDSClusterResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRDSClusterResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_route53_scanner_test.go b/enumeration/remote/aws_route53_scanner_test.go new file mode 100644 index 00000000..5ccedb44 --- /dev/null +++ b/enumeration/remote/aws_route53_scanner_test.go @@ -0,0 +1,476 @@ +package remote + +import ( + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + "testing" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/route53" + "github.com/pkg/errors" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestRoute53_HealthCheck(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockRoute53Repository, *mocks.AlerterInterface) + err error + }{ + { + test: "no health check", + dirName: "aws_route53_health_check_empty", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllHealthChecks").Return([]*route53.HealthCheck{}, nil) + }, + err: nil, + }, + { + test: "Multiple health check", + dirName: "aws_route53_health_check_multiple", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllHealthChecks").Return([]*route53.HealthCheck{ + {Id: awssdk.String("7001a9df-ded4-4802-9909-668eb80b972b")}, + {Id: awssdk.String("84fc318a-2e0d-41d6-b638-280e2f0f4e26")}, + }, nil) + }, + err: nil, + }, + { + test: "cannot list health check", + dirName: "aws_route53_health_check_empty", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllHealthChecks").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsRoute53HealthCheckResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53HealthCheckResourceType, resourceaws.AwsRoute53HealthCheckResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockRoute53Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.Route53Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewRoute53Repository(session, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewRoute53HealthCheckEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRoute53HealthCheckResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRoute53HealthCheckResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRoute53HealthCheckResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestRoute53_Zone(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockRoute53Repository, *mocks.AlerterInterface) + err error + }{ + { + test: "no zones", + dirName: "aws_route53_zone_empty", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{}, + nil, + ) + }, + err: nil, + }, + { + test: "single zone", + dirName: "aws_route53_zone_single", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{ + { + Id: awssdk.String("Z08068311RGDXPHF8KE62"), + Name: awssdk.String("foo.bar"), + }, + }, + nil, + ) + }, + err: nil, + }, + { + test: "multiples zone (test pagination)", + dirName: "aws_route53_zone_multiples", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{ + { + Id: awssdk.String("Z01809283VH9BBALZHO7B"), + Name: awssdk.String("foo-0.com"), + }, + { + Id: awssdk.String("Z01804312AV8PHE3C43AD"), + Name: awssdk.String("foo-1.com"), + }, + { + Id: awssdk.String("Z01874941AR1TCGV5K65C"), + Name: awssdk.String("foo-2.com"), + }, + }, + nil, + ) + }, + err: nil, + }, + { + test: "cannot list zones", + dirName: "aws_route53_zone_empty", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllZones").Return( + []*route53.HostedZone{}, + awsError, + ) + + alerter.On("SendAlert", resourceaws.AwsRoute53ZoneResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53ZoneResourceType, resourceaws.AwsRoute53ZoneResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockRoute53Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.Route53Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewRoute53Repository(session, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewRoute53ZoneEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRoute53ZoneResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRoute53ZoneResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRoute53ZoneResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestRoute53_Record(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockRoute53Repository, *mocks.AlerterInterface) + err error + }{ + { + test: "no records", + dirName: "aws_route53_zone_with_no_record", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{ + { + Id: awssdk.String("Z1035360GLIB82T1EH2G"), + Name: awssdk.String("foo-0.com"), + }, + }, + nil, + ) + client.On("ListRecordsForZone", "Z1035360GLIB82T1EH2G").Return([]*route53.ResourceRecordSet{}, nil) + }, + err: nil, + }, + { + test: "multiples records in multiples zones", + dirName: "aws_route53_record_multiples", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{ + { + Id: awssdk.String("Z1035360GLIB82T1EH2G"), + Name: awssdk.String("foo-0.com"), + }, + { + Id: awssdk.String("Z10347383HV75H96J919W"), + Name: awssdk.String("foo-1.com"), + }, + }, + nil, + ) + client.On("ListRecordsForZone", "Z1035360GLIB82T1EH2G").Return([]*route53.ResourceRecordSet{ + { + Name: awssdk.String("foo-0.com"), + Type: awssdk.String("NS"), + }, + { + Name: awssdk.String("test0"), + Type: awssdk.String("A"), + }, + { + Name: awssdk.String("test1"), + Type: awssdk.String("A"), + }, + { + Name: awssdk.String("test2"), + Type: awssdk.String("A"), + }, + { + Name: awssdk.String("\\052.test4."), + Type: awssdk.String("A"), + }, + }, nil) + client.On("ListRecordsForZone", "Z10347383HV75H96J919W").Return([]*route53.ResourceRecordSet{ + { + Name: awssdk.String("test2"), + Type: awssdk.String("A"), + }, + }, nil) + }, + err: nil, + }, + { + test: "explicit subdomain records", + dirName: "aws_route53_record_explicit_subdomain", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{ + { + Id: awssdk.String("Z06486383UC8WYSBZTWFM"), + Name: awssdk.String("foo-2.com"), + }, + }, + nil, + ) + client.On("ListRecordsForZone", "Z06486383UC8WYSBZTWFM").Return([]*route53.ResourceRecordSet{ + { + Name: awssdk.String("test0"), + Type: awssdk.String("TXT"), + }, + { + Name: awssdk.String("test0"), + Type: awssdk.String("A"), + }, + { + Name: awssdk.String("test1.foo-2.com"), + Type: awssdk.String("TXT"), + }, + { + Name: awssdk.String("test1.foo-2.com"), + Type: awssdk.String("A"), + }, + { + Name: awssdk.String("_test2.foo-2.com"), + Type: awssdk.String("TXT"), + }, + { + Name: awssdk.String("_test2.foo-2.com"), + Type: awssdk.String("A"), + }, + }, nil) + }, + err: nil, + }, + { + test: "cannot list zones", + dirName: "aws_route53_zone_with_no_record", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllZones").Return( + []*route53.HostedZone{}, + awsError) + + alerter.On("SendAlert", resourceaws.AwsRoute53RecordResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53RecordResourceType, resourceaws.AwsRoute53ZoneResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + { + test: "cannot list records", + dirName: "aws_route53_zone_with_no_record", + mocks: func(client *repository2.MockRoute53Repository, alerter *mocks.AlerterInterface) { + client.On("ListAllZones").Return( + []*route53.HostedZone{ + { + Id: awssdk.String("Z06486383UC8WYSBZTWFM"), + Name: awssdk.String("foo-2.com"), + }, + }, + nil) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListRecordsForZone", "Z06486383UC8WYSBZTWFM").Return( + []*route53.ResourceRecordSet{}, awsError) + + alerter.On("SendAlert", resourceaws.AwsRoute53RecordResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53RecordResourceType, resourceaws.AwsRoute53RecordResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockRoute53Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.Route53Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewRoute53Repository(session, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewRoute53RecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsRoute53RecordResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsRoute53RecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsRoute53RecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_s3_scanner_test.go b/enumeration/remote/aws_s3_scanner_test.go new file mode 100644 index 00000000..1616460f --- /dev/null +++ b/enumeration/remote/aws_s3_scanner_test.go @@ -0,0 +1,1084 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + "github.com/snyk/driftctl/enumeration/remote/aws/client" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + tf "github.com/snyk/driftctl/enumeration/remote/terraform" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/pkg/errors" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestS3Bucket(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "multiple bucket", dirName: "aws_s3_bucket_multiple", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + {Name: awssdk.String("bucket-martin-test-drift2")}, + {Name: awssdk.String("bucket-martin-test-drift3")}, + }, nil) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift2", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift3", + ).Return( + "ap-northeast-1", + nil, + ) + }, + }, + { + test: "cannot list bucket", dirName: "aws_s3_bucket_list", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllBuckets").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsS3BucketResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + var repo repository2.S3Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewS3BucketEnumerator(repo, factory, tf.TerraformProviderConfig{ + Name: "test", + DefaultAlias: "eu-west-3", + }, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsS3BucketResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestS3BucketInventory(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "multiple bucket with multiple inventories", dirName: "aws_s3_bucket_inventories_multiple", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + {Name: awssdk.String("bucket-martin-test-drift2")}, + {Name: awssdk.String("bucket-martin-test-drift3")}, + }, nil) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift2", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift3", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "ListBucketInventoryConfigurations", + &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift2")}, + "eu-west-3", + ).Return( + []*s3.InventoryConfiguration{ + {Id: awssdk.String("Inventory_Bucket2")}, + {Id: awssdk.String("Inventory2_Bucket2")}, + }, + nil, + ) + }, + }, + { + test: "cannot list bucket", dirName: "aws_s3_bucket_inventories_list_bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllBuckets").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsS3BucketInventoryResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketInventoryResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "cannot list bucket inventories", dirName: "aws_s3_bucket_inventories_list_inventories", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllBuckets").Return( + []*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + }, + nil, + ) + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-3", + nil, + ) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On( + "ListBucketInventoryConfigurations", + &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift")}, + "eu-west-3", + ).Return( + nil, + awsError, + ) + + alerter.On("SendAlert", resourceaws.AwsS3BucketInventoryResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketInventoryResourceType, resourceaws.AwsS3BucketInventoryResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + var repo repository2.S3Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewS3BucketInventoryEnumerator(repo, factory, tf.TerraformProviderConfig{ + Name: "test", + DefaultAlias: "eu-west-3", + }, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketInventoryResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsS3BucketInventoryResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketInventoryResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestS3BucketNotification(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "single bucket without notifications", + dirName: "aws_s3_bucket_notifications_no_notif", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("dritftctl-test-no-notifications")}, + }, nil) + + repository.On( + "GetBucketLocation", + "dritftctl-test-no-notifications", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketNotification", + "dritftctl-test-no-notifications", + "eu-west-3", + ).Return( + nil, + nil, + ) + }, + }, + { + test: "multiple bucket with notifications", dirName: "aws_s3_bucket_notifications_multiple", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + {Name: awssdk.String("bucket-martin-test-drift2")}, + {Name: awssdk.String("bucket-martin-test-drift3")}, + }, nil) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift2", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketNotification", + "bucket-martin-test-drift2", + "eu-west-3", + ).Return( + &s3.NotificationConfiguration{ + LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{ + { + Id: awssdk.String("tf-s3-lambda-20201103165354926600000001"), + }, + { + Id: awssdk.String("tf-s3-lambda-20201103165354926600000002"), + }, + }, + }, + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift3", + ).Return( + "ap-northeast-1", + nil, + ) + }, + }, + { + test: "Cannot get bucket notification", dirName: "aws_s3_bucket_notifications_list_bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("dritftctl-test-notifications-error")}, + }, nil) + repository.On( + "GetBucketLocation", + "dritftctl-test-notifications-error", + ).Return( + "eu-west-3", + nil, + ) + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("GetBucketNotification", "dritftctl-test-notifications-error", "eu-west-3").Return(nil, awsError) + + alerter.On("SendAlert", "aws_s3_bucket_notification.dritftctl-test-notifications-error", alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, "aws_s3_bucket_notification.dritftctl-test-notifications-error", resourceaws.AwsS3BucketNotificationResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "Cannot list bucket", dirName: "aws_s3_bucket_notifications_list_bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllBuckets").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsS3BucketNotificationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketNotificationResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + var repo repository2.S3Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewS3BucketNotificationEnumerator(repo, factory, tf.TerraformProviderConfig{ + Name: "test", + DefaultAlias: "eu-west-3", + }, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketNotificationResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsS3BucketNotificationResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketNotificationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestS3BucketMetrics(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "multiple bucket with multiple metrics", dirName: "aws_s3_bucket_metrics_multiple", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + {Name: awssdk.String("bucket-martin-test-drift2")}, + {Name: awssdk.String("bucket-martin-test-drift3")}, + }, nil) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift2", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift3", + ).Return( + "ap-northeast-1", + nil, + ) + + repository.On( + "ListBucketMetricsConfigurations", + &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift2")}, + "eu-west-3", + ).Return( + []*s3.MetricsConfiguration{ + {Id: awssdk.String("Metrics_Bucket2")}, + {Id: awssdk.String("Metrics2_Bucket2")}, + }, + nil, + ) + }, + }, + { + test: "cannot list bucket", dirName: "aws_s3_bucket_metrics_list_bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllBuckets").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsS3BucketMetricResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketMetricResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "cannot list metrics", dirName: "aws_s3_bucket_metrics_list_metrics", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllBuckets").Return( + []*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + }, + nil, + ) + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-3", + nil, + ) + + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On( + "ListBucketMetricsConfigurations", + &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift")}, + "eu-west-3", + ).Return( + nil, + awsError, + ) + + alerter.On("SendAlert", resourceaws.AwsS3BucketMetricResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketMetricResourceType, resourceaws.AwsS3BucketMetricResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + var repo repository2.S3Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewS3BucketMetricsEnumerator(repo, factory, tf.TerraformProviderConfig{ + Name: "test", + DefaultAlias: "eu-west-3", + }, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketMetricResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsS3BucketMetricResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketMetricResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestS3BucketPolicy(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "single bucket without policy", + dirName: "aws_s3_bucket_policy_no_policy", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("dritftctl-test-no-policy")}, + }, nil) + + repository.On( + "GetBucketLocation", + "dritftctl-test-no-policy", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketPolicy", + "dritftctl-test-no-policy", + "eu-west-3", + ).Return( + nil, + nil, + ) + }, + }, + { + test: "multiple bucket with policies", dirName: "aws_s3_bucket_policies_multiple", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + {Name: awssdk.String("bucket-martin-test-drift2")}, + {Name: awssdk.String("bucket-martin-test-drift3")}, + }, nil) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift2", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift3", + ).Return( + "ap-northeast-1", + nil, + ) + + repository.On( + "GetBucketPolicy", + "bucket-martin-test-drift2", + "eu-west-3", + ).Return( + // The value here not matter, we only want something not empty + // to trigger the detail fetcher + awssdk.String("foobar"), + nil, + ) + + }, + }, + { + test: "cannot list bucket", dirName: "aws_s3_bucket_policies_list_bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllBuckets").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsS3BucketPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketPolicyResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + var repo repository2.S3Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewS3BucketPolicyEnumerator(repo, factory, tf.TerraformProviderConfig{ + Name: "test", + DefaultAlias: "eu-west-3", + }, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketPolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsS3BucketPolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestS3BucketPublicAccessBlock(t *testing.T) { + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "multiple bucket, one with access block", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllBuckets").Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-with-public-access-block")}, + {Name: awssdk.String("bucket-without-public-access-block")}, + }, nil) + + repository.On("GetBucketLocation", "bucket-with-public-access-block"). + Return("us-east-1", nil) + repository.On("GetBucketLocation", "bucket-without-public-access-block"). + Return("us-east-1", nil) + + repository.On("GetBucketPublicAccessBlock", "bucket-with-public-access-block", "us-east-1"). + Return(&s3.PublicAccessBlockConfiguration{ + BlockPublicAcls: awssdk.Bool(true), + BlockPublicPolicy: awssdk.Bool(false), + }, nil) + + repository.On("GetBucketPublicAccessBlock", "bucket-without-public-access-block", "us-east-1"). + Return(nil, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, got[0].ResourceId(), "bucket-with-public-access-block") + assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3BucketPublicAccessBlockResourceType) + assert.Equal(t, got[0].Attributes(), &resource.Attributes{ + "block_public_acls": true, + "block_public_policy": false, + "ignore_public_acls": false, + "restrict_public_buckets": false, + }) + }, + }, + { + test: "cannot list bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllBuckets").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsS3BucketPublicAccessBlockResourceType, resourceaws.AwsS3BucketResourceType), + }, + { + test: "cannot list public access block", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllBuckets").Return([]*s3.Bucket{{Name: awssdk.String("foobar")}}, nil) + repository.On("GetBucketLocation", "foobar").Return("us-east-1", nil) + repository.On("GetBucketPublicAccessBlock", "foobar", "us-east-1").Return(nil, dummyError) + alerter.On("SendAlert", "aws_s3_bucket_public_access_block.foobar", alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsS3BucketPublicAccessBlockResourceType, "foobar"), alerts.EnumerationPhase)).Return() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + } + + providerVersion := "3.19.0" + schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.S3Repository = fakeRepo + + remoteLibrary.AddEnumerator(aws2.NewS3BucketPublicAccessBlockEnumerator( + repo, factory, + tf.TerraformProviderConfig{DefaultAlias: "us-east-1"}, + alerter, + )) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + }) + } +} + +func TestS3BucketAnalytic(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockS3Repository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "multiple bucket with multiple analytics", + dirName: "aws_s3_bucket_analytics_multiple", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On( + "ListAllBuckets", + ).Return([]*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + {Name: awssdk.String("bucket-martin-test-drift2")}, + {Name: awssdk.String("bucket-martin-test-drift3")}, + }, nil) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-1", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift2", + ).Return( + "eu-west-3", + nil, + ) + + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift3", + ).Return( + "ap-northeast-1", + nil, + ) + + repository.On( + "ListBucketAnalyticsConfigurations", + &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift2")}, + "eu-west-3", + ).Return( + []*s3.AnalyticsConfiguration{ + {Id: awssdk.String("Analytics_Bucket2")}, + {Id: awssdk.String("Analytics2_Bucket2")}, + }, + nil, + ) + }, + }, + { + test: "cannot list bucket", dirName: "aws_s3_bucket_analytics_list_bucket", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On("ListAllBuckets").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + { + test: "cannot list Analytics", dirName: "aws_s3_bucket_analytics_list_analytics", + mocks: func(repository *repository2.MockS3Repository, alerter *mocks.AlerterInterface) { + repository.On("ListAllBuckets").Return( + []*s3.Bucket{ + {Name: awssdk.String("bucket-martin-test-drift")}, + }, + nil, + ) + repository.On( + "GetBucketLocation", + "bucket-martin-test-drift", + ).Return( + "eu-west-3", + nil, + ) + + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + repository.On( + "ListBucketAnalyticsConfigurations", + &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift")}, + "eu-west-3", + ).Return( + nil, + awsError, + ) + + alerter.On("SendAlert", resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + session := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockS3Repository{} + c.mocks(fakeRepo, alerter) + var repo repository2.S3Repository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewS3BucketAnalyticEnumerator(repo, factory, tf.TerraformProviderConfig{ + Name: "test", + DefaultAlias: "eu-west-3", + }, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_sns_scanner_test.go b/enumeration/remote/aws_sns_scanner_test.go new file mode 100644 index 00000000..9ed38085 --- /dev/null +++ b/enumeration/remote/aws_sns_scanner_test.go @@ -0,0 +1,350 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/pkg/errors" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanSNSTopic(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockSNSRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no SNS Topic", + dirName: "aws_sns_topic_empty", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllTopics").Return([]*sns.Topic{}, nil) + }, + err: nil, + }, + { + test: "Multiple SNSTopic", + dirName: "aws_sns_topic_multiple", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllTopics").Return([]*sns.Topic{ + {TopicArn: awssdk.String("arn:aws:sns:eu-west-3:526954929923:user-updates-topic")}, + {TopicArn: awssdk.String("arn:aws:sns:eu-west-3:526954929923:user-updates-topic2")}, + {TopicArn: awssdk.String("arn:aws:sns:eu-west-3:526954929923:user-updates-topic3")}, + }, nil) + }, + err: nil, + }, + { + test: "cannot list SNSTopic", + dirName: "aws_sns_topic_empty", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllTopics").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSnsTopicResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSnsTopicResourceType, resourceaws.AwsSnsTopicResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockSNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.SNSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewSNSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewSNSTopicEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSnsTopicResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSnsTopicResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSnsTopicResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestSNSTopicPolicyScan(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockSNSRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no SNS Topic policy", + dirName: "aws_sns_topic_policy_empty", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllTopics").Return([]*sns.Topic{}, nil) + }, + err: nil, + }, + { + test: "Multiple SNSTopicPolicy", + dirName: "aws_sns_topic_policy_multiple", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllTopics").Return([]*sns.Topic{ + {TopicArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:my-topic-with-policy")}, + {TopicArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:my-topic-with-policy2")}, + }, nil) + }, + err: nil, + }, + { + test: "cannot list SNSTopic", + dirName: "aws_sns_topic_policy_topic_list", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllTopics").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSnsTopicPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSnsTopicPolicyResourceType, resourceaws.AwsSnsTopicResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockSNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.SNSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewSNSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewSNSTopicPolicyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSnsTopicPolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSnsTopicPolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.err, err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSnsTopicPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestSNSTopicSubscriptionScan(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*repository2.MockSNSRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no SNS Topic Subscription", + dirName: "aws_sns_topic_subscription_empty", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllSubscriptions").Return([]*sns.Subscription{}, nil) + }, + err: nil, + }, + { + test: "Multiple SNSTopic Subscription", + dirName: "aws_sns_topic_subscription_multiple", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllSubscriptions").Return([]*sns.Subscription{ + {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic2:c0f794c5-a009-4db4-9147-4c55959787fa")}, + {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic:b6e66147-2b31-4486-8d4b-2a2272264c8e")}, + }, nil) + }, + err: nil, + }, + { + test: "Multiple SNSTopic Subscription with one pending and one incorrect", + dirName: "aws_sns_topic_subscription_multiple", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllSubscriptions").Return([]*sns.Subscription{ + {SubscriptionArn: awssdk.String("PendingConfirmation"), Endpoint: awssdk.String("TEST")}, + {SubscriptionArn: awssdk.String("Incorrect"), Endpoint: awssdk.String("INCORRECT")}, + {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic2:c0f794c5-a009-4db4-9147-4c55959787fa")}, + {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic:b6e66147-2b31-4486-8d4b-2a2272264c8e")}, + }, nil) + + alerter.On("SendAlert", "aws_sns_topic_subscription.PendingConfirmation", aws2.NewWrongArnTopicAlert("PendingConfirmation", awssdk.String("TEST"))).Return() + + alerter.On("SendAlert", "aws_sns_topic_subscription.Incorrect", aws2.NewWrongArnTopicAlert("Incorrect", awssdk.String("INCORRECT"))).Return() + }, + err: nil, + }, + { + test: "cannot list SNSTopic subscription", + dirName: "aws_sns_topic_subscription_list", + mocks: func(client *repository2.MockSNSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllSubscriptions").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSnsTopicSubscriptionResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSnsTopicSubscriptionResourceType, resourceaws.AwsSnsTopicSubscriptionResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockSNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.SNSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewSNSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewSNSTopicSubscriptionEnumerator(repo, factory, alerter)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSnsTopicSubscriptionResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSnsTopicSubscriptionResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSnsTopicSubscriptionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/aws_sqs_scanner_test.go b/enumeration/remote/aws_sqs_scanner_test.go new file mode 100644 index 00000000..72b7345a --- /dev/null +++ b/enumeration/remote/aws_sqs_scanner_test.go @@ -0,0 +1,255 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + aws2 "github.com/snyk/driftctl/enumeration/remote/aws" + repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestSQSQueue(t *testing.T) { + cases := []struct { + test string + dirName string + mocks func(*repository2.MockSQSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no sqs queues", + dirName: "aws_sqs_queue_empty", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllQueues").Return([]*string{}, nil) + }, + wantErr: nil, + }, + { + test: "multiple sqs queues", + dirName: "aws_sqs_queue_multiple", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllQueues").Return([]*string{ + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), + }, nil) + }, + wantErr: nil, + }, + { + test: "cannot list sqs queues", + dirName: "aws_sqs_queue_empty", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllQueues").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSqsQueueResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSqsQueueResourceType, resourceaws.AwsSqsQueueResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockSQSRepository{} + c.mocks(fakeRepo, alerter) + var repo repository2.SQSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewSQSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewSQSQueueEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSqsQueueResourceType, aws2.NewSQSQueueDetailsFetcher(provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSqsQueueResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + fakeRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} + +func TestSQSQueuePolicy(t *testing.T) { + cases := []struct { + test string + dirName string + mocks func(*repository2.MockSQSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + // sqs queue with no policy case is not possible + // as a default SQSDefaultPolicy (e.g. policy="") will always be present in each queue + test: "no sqs queue policies", + dirName: "aws_sqs_queue_policy_empty", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllQueues").Return([]*string{}, nil) + }, + wantErr: nil, + }, + { + test: "multiple sqs queue policies (default or not)", + dirName: "aws_sqs_queue_policy_multiple", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllQueues").Return([]*string{ + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), + }, nil) + + client.On("GetQueueAttributes", mock.Anything).Return( + &sqs.GetQueueAttributesOutput{ + Attributes: map[string]*string{ + sqs.QueueAttributeNamePolicy: awssdk.String(""), + }, + }, + nil, + ) + }, + wantErr: nil, + }, + { + test: "multiple sqs queue policies (with nil attributes)", + dirName: "aws_sqs_queue_policy_multiple", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + client.On("ListAllQueues").Return([]*string{ + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), + awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), + }, nil) + + client.On("GetQueueAttributes", mock.Anything).Return( + &sqs.GetQueueAttributesOutput{}, + nil, + ) + }, + wantErr: nil, + }, + { + test: "cannot list sqs queues, thus sqs queue policies", + dirName: "aws_sqs_queue_policy_empty", + mocks: func(client *repository2.MockSQSRepository, alerter *mocks.AlerterInterface) { + awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") + client.On("ListAllQueues").Return(nil, awsError) + + alerter.On("SendAlert", resourceaws.AwsSqsQueuePolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSqsQueuePolicyResourceType, resourceaws.AwsSqsQueueResourceType), alerts.EnumerationPhase)).Return() + }, + wantErr: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") + resourceaws.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockSQSRepository{} + c.mocks(fakeRepo, alerter) + var repo repository2.SQSRepository = fakeRepo + providerVersion := "3.19.0" + realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = repository2.NewSQSRepository(sess, cache.New(0)) + } + + remoteLibrary.AddEnumerator(aws2.NewSQSQueuePolicyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceaws.AwsSqsQueuePolicyResourceType, common2.NewGenericDetailsFetcher(resourceaws.AwsSqsQueuePolicyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceaws.AwsSqsQueuePolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + fakeRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm/azurerm_container_registry_enumerator.go b/enumeration/remote/azurerm/azurerm_container_registry_enumerator.go new file mode 100644 index 00000000..66ed08e9 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_container_registry_enumerator.go @@ -0,0 +1,45 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermContainerRegistryEnumerator struct { + repository repository.ContainerRegistryRepository + factory resource.ResourceFactory +} + +func NewAzurermContainerRegistryEnumerator(repo repository.ContainerRegistryRepository, factory resource.ResourceFactory) *AzurermContainerRegistryEnumerator { + return &AzurermContainerRegistryEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermContainerRegistryEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureContainerRegistryResourceType +} + +func (e *AzurermContainerRegistryEnumerator) Enumerate() ([]*resource.Resource, error) { + registries, err := e.repository.ListAllContainerRegistries() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + for _, registry := range registries { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *registry.ID, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_firewalls_enumerator.go b/enumeration/remote/azurerm/azurerm_firewalls_enumerator.go new file mode 100644 index 00000000..6487acdd --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_firewalls_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermFirewallsEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermFirewallsEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermFirewallsEnumerator { + return &AzurermFirewallsEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermFirewallsEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureFirewallResourceType +} + +func (e *AzurermFirewallsEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllFirewalls() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_image_enumerator.go b/enumeration/remote/azurerm/azurerm_image_enumerator.go new file mode 100644 index 00000000..0fe11379 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_image_enumerator.go @@ -0,0 +1,65 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "strings" + + "github.com/Azure/go-autorest/autorest/azure" + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermImageEnumerator struct { + repository repository.ComputeRepository + factory resource.ResourceFactory +} + +func NewAzurermImageEnumerator(repo repository.ComputeRepository, factory resource.ResourceFactory) *AzurermImageEnumerator { + return &AzurermImageEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermImageEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureImageResourceType +} + +func (e *AzurermImageEnumerator) Enumerate() ([]*resource.Resource, error) { + images, err := e.repository.ListAllImages() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(images)) + + for _, res := range images { + r, err := azure.ParseResourceID(*res.ID) + if err != nil { + logrus.WithFields(map[string]interface{}{ + "id": *res.ID, + "type": string(e.SupportedType()), + }).Error("Failed to parse Azure resource ID") + continue + } + + // Here we turn the resource group into lowercase because for some reason the API returns it in uppercase. + resourceId := strings.Replace(*res.ID, r.ResourceGroup, strings.ToLower(r.ResourceGroup), 1) + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + resourceId, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_lb_enumerator.go b/enumeration/remote/azurerm/azurerm_lb_enumerator.go new file mode 100644 index 00000000..5ea34df3 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_lb_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermLoadBalancerEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermLoadBalancerEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermLoadBalancerEnumerator { + return &AzurermLoadBalancerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermLoadBalancerEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureLoadBalancerResourceType +} + +func (e *AzurermLoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) { + loadBalancers, err := e.repository.ListAllLoadBalancers() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(loadBalancers)) + + for _, res := range loadBalancers { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_lb_rule_enumerator.go b/enumeration/remote/azurerm/azurerm_lb_rule_enumerator.go new file mode 100644 index 00000000..01a781fe --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_lb_rule_enumerator.go @@ -0,0 +1,56 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermLoadBalancerRuleEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermLoadBalancerRuleEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermLoadBalancerRuleEnumerator { + return &AzurermLoadBalancerRuleEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermLoadBalancerRuleEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureLoadBalancerRuleResourceType +} + +func (e *AzurermLoadBalancerRuleEnumerator) Enumerate() ([]*resource.Resource, error) { + loadBalancers, err := e.repository.ListAllLoadBalancers() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureLoadBalancerResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, res := range loadBalancers { + rules, err := e.repository.ListLoadBalancerRules(res) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, rule := range rules { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *rule.ID, + map[string]interface{}{ + "name": *rule.Name, + "loadbalancer_id": *res.ID, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_network_security_group_enumerator.go b/enumeration/remote/azurerm/azurerm_network_security_group_enumerator.go new file mode 100644 index 00000000..aede9644 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_network_security_group_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermNetworkSecurityGroupEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermNetworkSecurityGroupEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermNetworkSecurityGroupEnumerator { + return &AzurermNetworkSecurityGroupEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermNetworkSecurityGroupEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureNetworkSecurityGroupResourceType +} + +func (e *AzurermNetworkSecurityGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + securityGroups, err := e.repository.ListAllSecurityGroups() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureNetworkSecurityGroupResourceType) + } + + results := make([]*resource.Resource, 0, len(securityGroups)) + + for _, res := range securityGroups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_postgresql_database_enumerator.go b/enumeration/remote/azurerm/azurerm_postgresql_database_enumerator.go new file mode 100644 index 00000000..2283ec38 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_postgresql_database_enumerator.go @@ -0,0 +1,54 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPostgresqlDatabaseEnumerator struct { + repository repository.PostgresqlRespository + factory resource.ResourceFactory +} + +func NewAzurermPostgresqlDatabaseEnumerator(repo repository.PostgresqlRespository, factory resource.ResourceFactory) *AzurermPostgresqlDatabaseEnumerator { + return &AzurermPostgresqlDatabaseEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPostgresqlDatabaseEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePostgresqlDatabaseResourceType +} + +func (e *AzurermPostgresqlDatabaseEnumerator) Enumerate() ([]*resource.Resource, error) { + servers, err := e.repository.ListAllServers() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePostgresqlServerResourceType) + } + + results := make([]*resource.Resource, 0) + for _, server := range servers { + databases, err := e.repository.ListAllDatabasesByServer(server) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, db := range databases { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *db.ID, + map[string]interface{}{ + "name": *db.Name, + }, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_postgresql_server_enumerator.go b/enumeration/remote/azurerm/azurerm_postgresql_server_enumerator.go new file mode 100644 index 00000000..a874f921 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_postgresql_server_enumerator.go @@ -0,0 +1,47 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPostgresqlServerEnumerator struct { + repository repository.PostgresqlRespository + factory resource.ResourceFactory +} + +func NewAzurermPostgresqlServerEnumerator(repo repository.PostgresqlRespository, factory resource.ResourceFactory) *AzurermPostgresqlServerEnumerator { + return &AzurermPostgresqlServerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPostgresqlServerEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePostgresqlServerResourceType +} + +func (e *AzurermPostgresqlServerEnumerator) Enumerate() ([]*resource.Resource, error) { + servers, err := e.repository.ListAllServers() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + for _, server := range servers { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *server.ID, + map[string]interface{}{ + "name": *server.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_private_dns_cname_record_enumerator.go b/enumeration/remote/azurerm/azurerm_private_dns_cname_record_enumerator.go new file mode 100644 index 00000000..0a69dac1 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_private_dns_cname_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSCNameRecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSCNameRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSCNameRecordEnumerator { + return &AzurermPrivateDNSCNameRecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSCNameRecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSCNameRecordResourceType +} + +func (e *AzurermPrivateDNSCNameRecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllCNAMERecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_a_record_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_a_record_enumerator.go new file mode 100644 index 00000000..6dc1fb25 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_a_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSARecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSARecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSARecordEnumerator { + return &AzurermPrivateDNSARecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSARecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSARecordResourceType +} + +func (e *AzurermPrivateDNSARecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllARecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_aaaa_record_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_aaaa_record_enumerator.go new file mode 100644 index 00000000..b14223bb --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_aaaa_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSAAAARecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSAAAARecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSAAAARecordEnumerator { + return &AzurermPrivateDNSAAAARecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSAAAARecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSAAAARecordResourceType +} + +func (e *AzurermPrivateDNSAAAARecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllAAAARecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_mx_record_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_mx_record_enumerator.go new file mode 100644 index 00000000..5c532e83 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_mx_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSMXRecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSMXRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSMXRecordEnumerator { + return &AzurermPrivateDNSMXRecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSMXRecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSMXRecordResourceType +} + +func (e *AzurermPrivateDNSMXRecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllMXRecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_ptr_record_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_ptr_record_enumerator.go new file mode 100644 index 00000000..fdfbd136 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_ptr_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSPTRRecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSPTRRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSPTRRecordEnumerator { + return &AzurermPrivateDNSPTRRecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSPTRRecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSPTRRecordResourceType +} + +func (e *AzurermPrivateDNSPTRRecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllPTRRecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_srv_record_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_srv_record_enumerator.go new file mode 100644 index 00000000..fdb99853 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_srv_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSSRVRecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSSRVRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSSRVRecordEnumerator { + return &AzurermPrivateDNSSRVRecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSSRVRecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSSRVRecordResourceType +} + +func (e *AzurermPrivateDNSSRVRecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllSRVRecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_txt_record_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_txt_record_enumerator.go new file mode 100644 index 00000000..ec0047fb --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_txt_record_enumerator.go @@ -0,0 +1,57 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSTXTRecordEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSTXTRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSTXTRecordEnumerator { + return &AzurermPrivateDNSTXTRecordEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSTXTRecordEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSTXTRecordResourceType +} + +func (e *AzurermPrivateDNSTXTRecordEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + records, err := e.repository.ListAllTXTRecords(zone) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, record := range records { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *record.ID, + map[string]interface{}{ + "name": *record.Name, + "zone_name": *zone.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_privatedns_zone_enumerator.go b/enumeration/remote/azurerm/azurerm_privatedns_zone_enumerator.go new file mode 100644 index 00000000..afa152a0 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_privatedns_zone_enumerator.go @@ -0,0 +1,49 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPrivateDNSZoneEnumerator struct { + repository repository.PrivateDNSRepository + factory resource.ResourceFactory +} + +func NewAzurermPrivateDNSZoneEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSZoneEnumerator { + return &AzurermPrivateDNSZoneEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPrivateDNSZoneEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePrivateDNSZoneResourceType +} + +func (e *AzurermPrivateDNSZoneEnumerator) Enumerate() ([]*resource.Resource, error) { + + zones, err := e.repository.ListAllPrivateZones() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + + for _, zone := range zones { + + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *zone.ID, + map[string]interface{}{}, + ), + ) + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_public_ip_enumerator.go b/enumeration/remote/azurerm/azurerm_public_ip_enumerator.go new file mode 100644 index 00000000..9b9dfe74 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_public_ip_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermPublicIPEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermPublicIPEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermPublicIPEnumerator { + return &AzurermPublicIPEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermPublicIPEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzurePublicIPResourceType +} + +func (e *AzurermPublicIPEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllPublicIPAddresses() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_resource_group_enumerator.go b/enumeration/remote/azurerm/azurerm_resource_group_enumerator.go new file mode 100644 index 00000000..07a0a23b --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_resource_group_enumerator.go @@ -0,0 +1,47 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermResourceGroupEnumerator struct { + repository repository.ResourcesRepository + factory resource.ResourceFactory +} + +func NewAzurermResourceGroupEnumerator(repo repository.ResourcesRepository, factory resource.ResourceFactory) *AzurermResourceGroupEnumerator { + return &AzurermResourceGroupEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermResourceGroupEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureResourceGroupResourceType +} + +func (e *AzurermResourceGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + groups, err := e.repository.ListAllResourceGroups() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0) + for _, group := range groups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *group.ID, + map[string]interface{}{ + "name": *group.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_route_enumerator.go b/enumeration/remote/azurerm/azurerm_route_enumerator.go new file mode 100644 index 00000000..72883051 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_route_enumerator.go @@ -0,0 +1,52 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermRouteEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermRouteEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermRouteEnumerator { + return &AzurermRouteEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermRouteEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureRouteResourceType +} + +func (e *AzurermRouteEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllRouteTables() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureRouteTableResourceType) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + for _, route := range res.Properties.Routes { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *route.ID, + map[string]interface{}{ + "name": *route.Name, + "route_table_name": *res.Name, + }, + ), + ) + } + + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_route_table_enumerator.go b/enumeration/remote/azurerm/azurerm_route_table_enumerator.go new file mode 100644 index 00000000..85a5bec4 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_route_table_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermRouteTableEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermRouteTableEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermRouteTableEnumerator { + return &AzurermRouteTableEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermRouteTableEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureRouteTableResourceType +} + +func (e *AzurermRouteTableEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllRouteTables() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_ssh_public_key_enumerator.go b/enumeration/remote/azurerm/azurerm_ssh_public_key_enumerator.go new file mode 100644 index 00000000..0ea0326d --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_ssh_public_key_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermSSHPublicKeyEnumerator struct { + repository repository.ComputeRepository + factory resource.ResourceFactory +} + +func NewAzurermSSHPublicKeyEnumerator(repo repository.ComputeRepository, factory resource.ResourceFactory) *AzurermSSHPublicKeyEnumerator { + return &AzurermSSHPublicKeyEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermSSHPublicKeyEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureSSHPublicKeyResourceType +} + +func (e *AzurermSSHPublicKeyEnumerator) Enumerate() ([]*resource.Resource, error) { + keys, err := e.repository.ListAllSSHPublicKeys() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(keys)) + + for _, res := range keys { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_storage_account_enumerator.go b/enumeration/remote/azurerm/azurerm_storage_account_enumerator.go new file mode 100644 index 00000000..0b1a8947 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_storage_account_enumerator.go @@ -0,0 +1,46 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermStorageAccountEnumerator struct { + repository repository.StorageRespository + factory resource.ResourceFactory +} + +func NewAzurermStorageAccountEnumerator(repo repository.StorageRespository, factory resource.ResourceFactory) *AzurermStorageAccountEnumerator { + return &AzurermStorageAccountEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermStorageAccountEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureStorageAccountResourceType +} + +func (e *AzurermStorageAccountEnumerator) Enumerate() ([]*resource.Resource, error) { + accounts, err := e.repository.ListAllStorageAccount() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(accounts)) + + for _, account := range accounts { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *account.ID, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_storage_container_enumerator.go b/enumeration/remote/azurerm/azurerm_storage_container_enumerator.go new file mode 100644 index 00000000..58729e6e --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_storage_container_enumerator.go @@ -0,0 +1,54 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermStorageContainerEnumerator struct { + repository repository.StorageRespository + factory resource.ResourceFactory +} + +func NewAzurermStorageContainerEnumerator(repo repository.StorageRespository, factory resource.ResourceFactory) *AzurermStorageContainerEnumerator { + return &AzurermStorageContainerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermStorageContainerEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureStorageContainerResourceType +} + +func (e *AzurermStorageContainerEnumerator) Enumerate() ([]*resource.Resource, error) { + + accounts, err := e.repository.ListAllStorageAccount() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureStorageAccountResourceType) + } + + results := make([]*resource.Resource, 0) + + for _, account := range accounts { + containers, err := e.repository.ListAllStorageContainer(account) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + for _, container := range containers { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + container, + map[string]interface{}{}, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_subnets_enumerator.go b/enumeration/remote/azurerm/azurerm_subnets_enumerator.go new file mode 100644 index 00000000..4f2bffa5 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_subnets_enumerator.go @@ -0,0 +1,51 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermSubnetEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermSubnetEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermSubnetEnumerator { + return &AzurermSubnetEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermSubnetEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureSubnetResourceType +} + +func (e *AzurermSubnetEnumerator) Enumerate() ([]*resource.Resource, error) { + networks, err := e.repository.ListAllVirtualNetworks() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureVirtualNetworkResourceType) + } + + results := make([]*resource.Resource, 0) + for _, network := range networks { + resources, err := e.repository.ListAllSubnets(network) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{}, + ), + ) + } + } + + return results, err +} diff --git a/enumeration/remote/azurerm/azurerm_virtual_network_enumerator.go b/enumeration/remote/azurerm/azurerm_virtual_network_enumerator.go new file mode 100644 index 00000000..b55c8a78 --- /dev/null +++ b/enumeration/remote/azurerm/azurerm_virtual_network_enumerator.go @@ -0,0 +1,48 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) + +type AzurermVirtualNetworkEnumerator struct { + repository repository.NetworkRepository + factory resource.ResourceFactory +} + +func NewAzurermVirtualNetworkEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermVirtualNetworkEnumerator { + return &AzurermVirtualNetworkEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *AzurermVirtualNetworkEnumerator) SupportedType() resource.ResourceType { + return azurerm.AzureVirtualNetworkResourceType +} + +func (e *AzurermVirtualNetworkEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.ListAllVirtualNetworks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + *res.ID, + map[string]interface{}{ + "name": *res.Name, + }, + ), + ) + } + + return results, err +} diff --git a/pkg/remote/azurerm/common/config.go b/enumeration/remote/azurerm/common/config.go similarity index 100% rename from pkg/remote/azurerm/common/config.go rename to enumeration/remote/azurerm/common/config.go diff --git a/enumeration/remote/azurerm/init.go b/enumeration/remote/azurerm/init.go new file mode 100644 index 00000000..a0eb4480 --- /dev/null +++ b/enumeration/remote/azurerm/init.go @@ -0,0 +1,105 @@ +package azurerm + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/alerter" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/enumeration/terraform" +) + +func Init( + version string, + alerter *alerter.Alerter, + providerLibrary *terraform.ProviderLibrary, + remoteLibrary *common2.RemoteLibrary, + progress enumeration.ProgressCounter, + resourceSchemaRepository *resource.SchemaRepository, + factory resource.ResourceFactory, + configDir string) error { + + provider, err := NewAzureTerraformProvider(version, progress, configDir) + if err != nil { + return err + } + err = provider.CheckCredentialsExist() + if err != nil { + return err + } + err = provider.Init() + if err != nil { + return err + } + + providerConfig := provider.GetConfig() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + return err + } + clientOptions := &arm.ClientOptions{} + + c := cache.New(100) + + storageAccountRepo := repository2.NewStorageRepository(cred, clientOptions, providerConfig, c) + networkRepo := repository2.NewNetworkRepository(cred, clientOptions, providerConfig, c) + resourcesRepo := repository2.NewResourcesRepository(cred, clientOptions, providerConfig, c) + containerRegistryRepo := repository2.NewContainerRegistryRepository(cred, clientOptions, providerConfig, c) + postgresqlRepo := repository2.NewPostgresqlRepository(cred, clientOptions, providerConfig, c) + privateDNSRepo := repository2.NewPrivateDNSRepository(cred, clientOptions, providerConfig, c) + computeRepo := repository2.NewComputeRepository(cred, clientOptions, providerConfig, c) + + providerLibrary.AddProvider(terraform.AZURE, provider) + deserializer := resource.NewDeserializer(factory) + + remoteLibrary.AddEnumerator(NewAzurermStorageAccountEnumerator(storageAccountRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermStorageContainerEnumerator(storageAccountRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermVirtualNetworkEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermRouteTableEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermRouteEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermResourceGroupEnumerator(resourcesRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermSubnetEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermContainerRegistryEnumerator(containerRegistryRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermFirewallsEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermPostgresqlServerEnumerator(postgresqlRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermPublicIPEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermPostgresqlDatabaseEnumerator(postgresqlRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermNetworkSecurityGroupEnumerator(networkRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzureNetworkSecurityGroupResourceType, common2.NewGenericDetailsFetcher(azurerm.AzureNetworkSecurityGroupResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermLoadBalancerEnumerator(networkRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermLoadBalancerRuleEnumerator(networkRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzureLoadBalancerRuleResourceType, common2.NewGenericDetailsFetcher(azurerm.AzureLoadBalancerRuleResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSZoneEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSZoneResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSZoneResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSARecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSARecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSARecordResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSAAAARecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSAAAARecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSAAAARecordResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSMXRecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSMXRecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSMXRecordResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSCNameRecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSCNameRecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSCNameRecordResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSPTRRecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSPTRRecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSPTRRecordResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSSRVRecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSSRVRecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSSRVRecordResourceType, provider, deserializer)) + remoteLibrary.AddEnumerator(NewAzurermPrivateDNSTXTRecordEnumerator(privateDNSRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSTXTRecordResourceType, common2.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSTXTRecordResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewAzurermImageEnumerator(computeRepo, factory)) + remoteLibrary.AddEnumerator(NewAzurermSSHPublicKeyEnumerator(computeRepo, factory)) + remoteLibrary.AddDetailsFetcher(azurerm.AzureSSHPublicKeyResourceType, common2.NewGenericDetailsFetcher(azurerm.AzureSSHPublicKeyResourceType, provider, deserializer)) + + err = resourceSchemaRepository.Init(terraform.AZURE, provider.Version(), provider.Schema()) + if err != nil { + return err + } + azurerm.InitResourcesMetadata(resourceSchemaRepository) + + return nil +} diff --git a/enumeration/remote/azurerm/provider.go b/enumeration/remote/azurerm/provider.go new file mode 100644 index 00000000..f516fa49 --- /dev/null +++ b/enumeration/remote/azurerm/provider.go @@ -0,0 +1,95 @@ +package azurerm + +import ( + "context" + "errors" + "os" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/terraform" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" +) + +type AzureTerraformProvider struct { + *terraform.TerraformProvider + name string + version string +} + +func NewAzureTerraformProvider(version string, progress enumeration.ProgressCounter, configDir string) (*AzureTerraformProvider, error) { + if version == "" { + version = "2.71.0" + } + // Just pass your version and name + p := &AzureTerraformProvider{ + version: version, + name: terraform2.AZURE, + } + // Use TerraformProviderInstaller to retrieve the provider if needed + installer, err := terraform2.NewProviderInstaller(terraform2.ProviderConfig{ + Key: p.name, + Version: version, + ConfigDir: configDir, + }) + if err != nil { + return nil, err + } + + tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ + Name: p.name, + GetProviderConfig: func(_ string) interface{} { + c := p.GetConfig() + return map[string]interface{}{ + "subscription_id": c.SubscriptionID, + "tenant_id": c.TenantID, + "client_id": c.ClientID, + "client_secret": c.ClientSecret, + "skip_provider_registration": true, + } + }, + }, progress) + if err != nil { + return nil, err + } + p.TerraformProvider = tfProvider + return p, err +} + +func (p *AzureTerraformProvider) GetConfig() common.AzureProviderConfig { + return common.AzureProviderConfig{ + SubscriptionID: os.Getenv("AZURE_SUBSCRIPTION_ID"), + TenantID: os.Getenv("AZURE_TENANT_ID"), + ClientID: os.Getenv("AZURE_CLIENT_ID"), + ClientSecret: os.Getenv("AZURE_CLIENT_SECRET"), + } +} + +func (p *AzureTerraformProvider) Name() string { + return p.name +} + +func (p *AzureTerraformProvider) Version() string { + return p.version +} + +func (p *AzureTerraformProvider) CheckCredentialsExist() error { + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + return err + } + + _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com//.default"}}) + if err != nil { + return errors.New("Could not find any authentication method for Azure.\n" + + "For more information, please check the official Azure documentation: https://docs.microsoft.com/en-us/azure/developer/go/azure-sdk-authorization#use-environment-based-authentication") + } + + if p.GetConfig().SubscriptionID == "" { + return errors.New("Please provide an Azure subscription ID by setting the `AZURE_SUBSCRIPTION_ID` environment variable.") + } + + return nil +} diff --git a/enumeration/remote/azurerm/repository/compute.go b/enumeration/remote/azurerm/repository/compute.go new file mode 100644 index 00000000..41499532 --- /dev/null +++ b/enumeration/remote/azurerm/repository/compute.go @@ -0,0 +1,110 @@ +package repository + +import ( + "context" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" +) + +type ComputeRepository interface { + ListAllImages() ([]*armcompute.Image, error) + ListAllSSHPublicKeys() ([]*armcompute.SSHPublicKeyResource, error) +} + +type imagesListPager interface { + pager + PageResponse() armcompute.ImagesListResponse +} + +type imagesClient interface { + List(options *armcompute.ImagesListOptions) imagesListPager +} + +type imagesClientImpl struct { + client *armcompute.ImagesClient +} + +func (c imagesClientImpl) List(options *armcompute.ImagesListOptions) imagesListPager { + return c.client.List(options) +} + +type sshPublicKeyListPager interface { + pager + PageResponse() armcompute.SSHPublicKeysListBySubscriptionResponse +} + +type sshPublicKeyClient interface { + ListBySubscription(options *armcompute.SSHPublicKeysListBySubscriptionOptions) sshPublicKeyListPager +} + +type sshPublicKeyClientImpl struct { + client *armcompute.SSHPublicKeysClient +} + +func (c sshPublicKeyClientImpl) ListBySubscription(options *armcompute.SSHPublicKeysListBySubscriptionOptions) sshPublicKeyListPager { + return c.client.ListBySubscription(options) +} + +type computeRepository struct { + imagesClient imagesClient + sshPublicKeyClient sshPublicKeyClient + cache cache.Cache +} + +func NewComputeRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *computeRepository { + return &computeRepository{ + &imagesClientImpl{armcompute.NewImagesClient(config.SubscriptionID, cred, options)}, + &sshPublicKeyClientImpl{armcompute.NewSSHPublicKeysClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *computeRepository) ListAllImages() ([]*armcompute.Image, error) { + cacheKey := "computeListAllImages" + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armcompute.Image), nil + } + + pager := s.imagesClient.List(nil) + results := make([]*armcompute.Image, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + return results, nil +} + +func (s *computeRepository) ListAllSSHPublicKeys() ([]*armcompute.SSHPublicKeyResource, error) { + cacheKey := "computeListAllSSHPublicKeys" + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armcompute.SSHPublicKeyResource), nil + } + + pager := s.sshPublicKeyClient.ListBySubscription(nil) + results := make([]*armcompute.SSHPublicKeyResource, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + return results, nil +} diff --git a/enumeration/remote/azurerm/repository/compute_test.go b/enumeration/remote/azurerm/repository/compute_test.go new file mode 100644 index 00000000..2769be5d --- /dev/null +++ b/enumeration/remote/azurerm/repository/compute_test.go @@ -0,0 +1,275 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Compute_ListAllImages(t *testing.T) { + expectedResults := []*armcompute.Image{ + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/tfvmex-resources/providers/Microsoft.Compute/images/image1"), + Name: to.StringPtr("image1"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/tfvmex-resources/providers/Microsoft.Compute/images/image2"), + Name: to.StringPtr("image2"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/tfvmex-resources/providers/Microsoft.Compute/images/image3"), + Name: to.StringPtr("image3"), + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockImagesListPager, *cache.MockCache) + expected []*armcompute.Image + wantErr string + }{ + { + name: "should return images", + mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ + ImagesListResult: armcompute.ImagesListResult{ + ImageListResult: armcompute.ImageListResult{ + Value: expectedResults[:2], + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ + ImagesListResult: armcompute.ImagesListResult{ + ImageListResult: armcompute.ImageListResult{ + Value: expectedResults[2:], + }, + }, + }).Times(1) + + mockCache.On("Get", "computeListAllImages").Return(nil).Times(1) + mockCache.On("Put", "computeListAllImages", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return images", + mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { + mockCache.On("Get", "computeListAllImages").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ + ImagesListResult: armcompute.ImagesListResult{ + ImageListResult: armcompute.ImageListResult{ + Value: []*armcompute.Image{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "computeListAllImages").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + { + name: "should return remote error after fetching all pages", + mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ + ImagesListResult: armcompute.ImagesListResult{ + ImageListResult: armcompute.ImageListResult{ + Value: []*armcompute.Image{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(nil).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "computeListAllImages").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakeClient := &mockImagesClient{} + mockPager := &mockImagesListPager{} + mockCache := &cache.MockCache{} + + fakeClient.On("List", mock.Anything).Maybe().Return(mockPager) + + tt.mocks(mockPager, mockCache) + + s := &computeRepository{ + imagesClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllImages() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockPager.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} + +func Test_Compute_ListAllSSHPublicKeys(t *testing.T) { + expectedResults := []*armcompute.SSHPublicKeyResource{ + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/key1"), + Name: to.StringPtr("key1"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/key2"), + Name: to.StringPtr("key2"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/key3"), + Name: to.StringPtr("key3"), + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockSshPublicKeyListPager, *cache.MockCache) + expected []*armcompute.SSHPublicKeyResource + wantErr string + }{ + { + name: "should return SSH public keys", + mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ + SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ + SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ + Value: expectedResults[:2], + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ + SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ + SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ + Value: expectedResults[2:], + }, + }, + }).Times(1) + + mockCache.On("Get", "computeListAllSSHPublicKeys").Return(nil).Times(1) + mockCache.On("Put", "computeListAllSSHPublicKeys", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return SSH public keys", + mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { + mockCache.On("Get", "computeListAllSSHPublicKeys").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ + SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ + SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ + Value: []*armcompute.SSHPublicKeyResource{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "computeListAllSSHPublicKeys").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + { + name: "should return remote error after fetching all pages", + mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ + SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ + SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ + Value: []*armcompute.SSHPublicKeyResource{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(nil).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "computeListAllSSHPublicKeys").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakeClient := &mockSshPublicKeyClient{} + mockPager := &mockSshPublicKeyListPager{} + mockCache := &cache.MockCache{} + + fakeClient.On("ListBySubscription", mock.Anything).Maybe().Return(mockPager) + + tt.mocks(mockPager, mockCache) + + s := &computeRepository{ + sshPublicKeyClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllSSHPublicKeys() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockPager.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/enumeration/remote/azurerm/repository/containerregistry.go b/enumeration/remote/azurerm/repository/containerregistry.go new file mode 100644 index 00000000..99245a9b --- /dev/null +++ b/enumeration/remote/azurerm/repository/containerregistry.go @@ -0,0 +1,69 @@ +package repository + +import ( + "context" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" +) + +type ContainerRegistryRepository interface { + ListAllContainerRegistries() ([]*armcontainerregistry.Registry, error) +} + +type registryClient interface { + List(options *armcontainerregistry.RegistriesListOptions) registryListAllPager +} + +type registryListAllPager interface { + pager + PageResponse() armcontainerregistry.RegistriesListResponse +} + +type registryClientImpl struct { + client *armcontainerregistry.RegistriesClient +} + +func (c registryClientImpl) List(options *armcontainerregistry.RegistriesListOptions) registryListAllPager { + return c.client.List(options) +} + +type containerRegistryRepository struct { + registryClient registryClient + cache cache.Cache +} + +func NewContainerRegistryRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *containerRegistryRepository { + return &containerRegistryRepository{ + ®istryClientImpl{client: armcontainerregistry.NewRegistriesClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *containerRegistryRepository) ListAllContainerRegistries() ([]*armcontainerregistry.Registry, error) { + + if v := s.cache.Get("ListAllContainerRegistries"); v != nil { + return v.([]*armcontainerregistry.Registry), nil + } + + pager := s.registryClient.List(nil) + results := make([]*armcontainerregistry.Registry, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put("ListAllContainerRegistries", results) + + return results, nil +} diff --git a/enumeration/remote/azurerm/repository/containerregistry_test.go b/enumeration/remote/azurerm/repository/containerregistry_test.go new file mode 100644 index 00000000..a204b294 --- /dev/null +++ b/enumeration/remote/azurerm/repository/containerregistry_test.go @@ -0,0 +1,144 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Resources_ListAllContainerRegistries(t *testing.T) { + expectedResults := []*armcontainerregistry.Registry{ + { + Resource: armcontainerregistry.Resource{ + ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/my-group/providers/Microsoft.ContainerRegistry/registries/containerRegistry1"), + Name: to.StringPtr("containerRegistry1"), + }, + }, + { + Resource: armcontainerregistry.Resource{ + ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/my-group/providers/Microsoft.ContainerRegistry/registries/containerRegistry1"), + Name: to.StringPtr("containerRegistry2"), + }, + }, + { + Resource: armcontainerregistry.Resource{ + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/-/resource-3"), + Name: to.StringPtr("resource-3"), + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockRegistryListAllPager, *cache.MockCache) + expected []*armcontainerregistry.Registry + wantErr string + }{ + { + name: "should return container registries", + mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ + RegistriesListResult: armcontainerregistry.RegistriesListResult{ + RegistryListResult: armcontainerregistry.RegistryListResult{ + Value: expectedResults[:2], + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ + RegistriesListResult: armcontainerregistry.RegistriesListResult{ + RegistryListResult: armcontainerregistry.RegistryListResult{ + Value: expectedResults[2:], + }, + }, + }).Times(1) + + mockCache.On("Get", "ListAllContainerRegistries").Return(nil).Times(1) + mockCache.On("Put", "ListAllContainerRegistries", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return container registries", + mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { + mockCache.On("Get", "ListAllContainerRegistries").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ + RegistriesListResult: armcontainerregistry.RegistriesListResult{ + RegistryListResult: armcontainerregistry.RegistryListResult{ + Value: []*armcontainerregistry.Registry{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "ListAllContainerRegistries").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + { + name: "should return remote error after fetching all pages", + mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ + RegistriesListResult: armcontainerregistry.RegistriesListResult{ + RegistryListResult: armcontainerregistry.RegistryListResult{ + Value: []*armcontainerregistry.Registry{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(nil).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "ListAllContainerRegistries").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakeClient := &mockRegistryClient{} + mockPager := &mockRegistryListAllPager{} + mockCache := &cache.MockCache{} + + fakeClient.On("List", mock.Anything).Maybe().Return(mockPager) + + tt.mocks(mockPager, mockCache) + + s := &containerRegistryRepository{ + registryClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllContainerRegistries() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockPager.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/pkg/remote/azurerm/repository/mock_ComputeRepository.go b/enumeration/remote/azurerm/repository/mock_ComputeRepository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_ComputeRepository.go rename to enumeration/remote/azurerm/repository/mock_ComputeRepository.go diff --git a/pkg/remote/azurerm/repository/mock_ContainerRegistryRepository.go b/enumeration/remote/azurerm/repository/mock_ContainerRegistryRepository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_ContainerRegistryRepository.go rename to enumeration/remote/azurerm/repository/mock_ContainerRegistryRepository.go diff --git a/pkg/remote/azurerm/repository/mock_NetworkRepository.go b/enumeration/remote/azurerm/repository/mock_NetworkRepository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_NetworkRepository.go rename to enumeration/remote/azurerm/repository/mock_NetworkRepository.go diff --git a/pkg/remote/azurerm/repository/mock_PostgresqlRespository.go b/enumeration/remote/azurerm/repository/mock_PostgresqlRespository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_PostgresqlRespository.go rename to enumeration/remote/azurerm/repository/mock_PostgresqlRespository.go diff --git a/pkg/remote/azurerm/repository/mock_PrivateDNSRepository.go b/enumeration/remote/azurerm/repository/mock_PrivateDNSRepository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_PrivateDNSRepository.go rename to enumeration/remote/azurerm/repository/mock_PrivateDNSRepository.go diff --git a/pkg/remote/azurerm/repository/mock_ResourcesRepository.go b/enumeration/remote/azurerm/repository/mock_ResourcesRepository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_ResourcesRepository.go rename to enumeration/remote/azurerm/repository/mock_ResourcesRepository.go diff --git a/pkg/remote/azurerm/repository/mock_StorageRespository.go b/enumeration/remote/azurerm/repository/mock_StorageRespository.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_StorageRespository.go rename to enumeration/remote/azurerm/repository/mock_StorageRespository.go diff --git a/pkg/remote/azurerm/repository/mock_blobContainerClient.go b/enumeration/remote/azurerm/repository/mock_blobContainerClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_blobContainerClient.go rename to enumeration/remote/azurerm/repository/mock_blobContainerClient.go diff --git a/pkg/remote/azurerm/repository/mock_blobContainerListPager.go b/enumeration/remote/azurerm/repository/mock_blobContainerListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_blobContainerListPager.go rename to enumeration/remote/azurerm/repository/mock_blobContainerListPager.go diff --git a/pkg/remote/azurerm/repository/mock_firewallsClient.go b/enumeration/remote/azurerm/repository/mock_firewallsClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_firewallsClient.go rename to enumeration/remote/azurerm/repository/mock_firewallsClient.go diff --git a/pkg/remote/azurerm/repository/mock_firewallsListAllPager.go b/enumeration/remote/azurerm/repository/mock_firewallsListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_firewallsListAllPager.go rename to enumeration/remote/azurerm/repository/mock_firewallsListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_imagesClient.go b/enumeration/remote/azurerm/repository/mock_imagesClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_imagesClient.go rename to enumeration/remote/azurerm/repository/mock_imagesClient.go diff --git a/pkg/remote/azurerm/repository/mock_imagesListPager.go b/enumeration/remote/azurerm/repository/mock_imagesListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_imagesListPager.go rename to enumeration/remote/azurerm/repository/mock_imagesListPager.go diff --git a/pkg/remote/azurerm/repository/mock_loadBalancerRulesClient.go b/enumeration/remote/azurerm/repository/mock_loadBalancerRulesClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_loadBalancerRulesClient.go rename to enumeration/remote/azurerm/repository/mock_loadBalancerRulesClient.go diff --git a/pkg/remote/azurerm/repository/mock_loadBalancerRulesListAllPager.go b/enumeration/remote/azurerm/repository/mock_loadBalancerRulesListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_loadBalancerRulesListAllPager.go rename to enumeration/remote/azurerm/repository/mock_loadBalancerRulesListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_loadBalancersClient.go b/enumeration/remote/azurerm/repository/mock_loadBalancersClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_loadBalancersClient.go rename to enumeration/remote/azurerm/repository/mock_loadBalancersClient.go diff --git a/pkg/remote/azurerm/repository/mock_loadBalancersListAllPager.go b/enumeration/remote/azurerm/repository/mock_loadBalancersListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_loadBalancersListAllPager.go rename to enumeration/remote/azurerm/repository/mock_loadBalancersListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_networkSecurityGroupsClient.go b/enumeration/remote/azurerm/repository/mock_networkSecurityGroupsClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_networkSecurityGroupsClient.go rename to enumeration/remote/azurerm/repository/mock_networkSecurityGroupsClient.go diff --git a/pkg/remote/azurerm/repository/mock_networkSecurityGroupsListAllPager.go b/enumeration/remote/azurerm/repository/mock_networkSecurityGroupsListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_networkSecurityGroupsListAllPager.go rename to enumeration/remote/azurerm/repository/mock_networkSecurityGroupsListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_postgresqlDatabaseClient.go b/enumeration/remote/azurerm/repository/mock_postgresqlDatabaseClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_postgresqlDatabaseClient.go rename to enumeration/remote/azurerm/repository/mock_postgresqlDatabaseClient.go diff --git a/pkg/remote/azurerm/repository/mock_postgresqlServersClient.go b/enumeration/remote/azurerm/repository/mock_postgresqlServersClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_postgresqlServersClient.go rename to enumeration/remote/azurerm/repository/mock_postgresqlServersClient.go diff --git a/pkg/remote/azurerm/repository/mock_privateDNSRecordSetListPager.go b/enumeration/remote/azurerm/repository/mock_privateDNSRecordSetListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_privateDNSRecordSetListPager.go rename to enumeration/remote/azurerm/repository/mock_privateDNSRecordSetListPager.go diff --git a/pkg/remote/azurerm/repository/mock_privateDNSZoneListPager.go b/enumeration/remote/azurerm/repository/mock_privateDNSZoneListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_privateDNSZoneListPager.go rename to enumeration/remote/azurerm/repository/mock_privateDNSZoneListPager.go diff --git a/pkg/remote/azurerm/repository/mock_privateRecordSetClient.go b/enumeration/remote/azurerm/repository/mock_privateRecordSetClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_privateRecordSetClient.go rename to enumeration/remote/azurerm/repository/mock_privateRecordSetClient.go diff --git a/pkg/remote/azurerm/repository/mock_privateZonesClient.go b/enumeration/remote/azurerm/repository/mock_privateZonesClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_privateZonesClient.go rename to enumeration/remote/azurerm/repository/mock_privateZonesClient.go diff --git a/pkg/remote/azurerm/repository/mock_publicIPAddressesClient.go b/enumeration/remote/azurerm/repository/mock_publicIPAddressesClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_publicIPAddressesClient.go rename to enumeration/remote/azurerm/repository/mock_publicIPAddressesClient.go diff --git a/pkg/remote/azurerm/repository/mock_publicIPAddressesListAllPager.go b/enumeration/remote/azurerm/repository/mock_publicIPAddressesListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_publicIPAddressesListAllPager.go rename to enumeration/remote/azurerm/repository/mock_publicIPAddressesListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_registryClient.go b/enumeration/remote/azurerm/repository/mock_registryClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_registryClient.go rename to enumeration/remote/azurerm/repository/mock_registryClient.go diff --git a/pkg/remote/azurerm/repository/mock_registryListAllPager.go b/enumeration/remote/azurerm/repository/mock_registryListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_registryListAllPager.go rename to enumeration/remote/azurerm/repository/mock_registryListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_resourcesClient.go b/enumeration/remote/azurerm/repository/mock_resourcesClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_resourcesClient.go rename to enumeration/remote/azurerm/repository/mock_resourcesClient.go diff --git a/pkg/remote/azurerm/repository/mock_resourcesListPager.go b/enumeration/remote/azurerm/repository/mock_resourcesListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_resourcesListPager.go rename to enumeration/remote/azurerm/repository/mock_resourcesListPager.go diff --git a/pkg/remote/azurerm/repository/mock_routeTablesClient.go b/enumeration/remote/azurerm/repository/mock_routeTablesClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_routeTablesClient.go rename to enumeration/remote/azurerm/repository/mock_routeTablesClient.go diff --git a/pkg/remote/azurerm/repository/mock_routeTablesListAllPager.go b/enumeration/remote/azurerm/repository/mock_routeTablesListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_routeTablesListAllPager.go rename to enumeration/remote/azurerm/repository/mock_routeTablesListAllPager.go diff --git a/pkg/remote/azurerm/repository/mock_sshPublicKeyClient.go b/enumeration/remote/azurerm/repository/mock_sshPublicKeyClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_sshPublicKeyClient.go rename to enumeration/remote/azurerm/repository/mock_sshPublicKeyClient.go diff --git a/pkg/remote/azurerm/repository/mock_sshPublicKeyListPager.go b/enumeration/remote/azurerm/repository/mock_sshPublicKeyListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_sshPublicKeyListPager.go rename to enumeration/remote/azurerm/repository/mock_sshPublicKeyListPager.go diff --git a/pkg/remote/azurerm/repository/mock_storageAccountClient.go b/enumeration/remote/azurerm/repository/mock_storageAccountClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_storageAccountClient.go rename to enumeration/remote/azurerm/repository/mock_storageAccountClient.go diff --git a/pkg/remote/azurerm/repository/mock_storageAccountListPager.go b/enumeration/remote/azurerm/repository/mock_storageAccountListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_storageAccountListPager.go rename to enumeration/remote/azurerm/repository/mock_storageAccountListPager.go diff --git a/pkg/remote/azurerm/repository/mock_subnetsClient.go b/enumeration/remote/azurerm/repository/mock_subnetsClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_subnetsClient.go rename to enumeration/remote/azurerm/repository/mock_subnetsClient.go diff --git a/pkg/remote/azurerm/repository/mock_subnetsListPager.go b/enumeration/remote/azurerm/repository/mock_subnetsListPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_subnetsListPager.go rename to enumeration/remote/azurerm/repository/mock_subnetsListPager.go diff --git a/pkg/remote/azurerm/repository/mock_virtualNetworkClient.go b/enumeration/remote/azurerm/repository/mock_virtualNetworkClient.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_virtualNetworkClient.go rename to enumeration/remote/azurerm/repository/mock_virtualNetworkClient.go diff --git a/pkg/remote/azurerm/repository/mock_virtualNetworksListAllPager.go b/enumeration/remote/azurerm/repository/mock_virtualNetworksListAllPager.go similarity index 100% rename from pkg/remote/azurerm/repository/mock_virtualNetworksListAllPager.go rename to enumeration/remote/azurerm/repository/mock_virtualNetworksListAllPager.go diff --git a/enumeration/remote/azurerm/repository/network.go b/enumeration/remote/azurerm/repository/network.go new file mode 100644 index 00000000..3c65d112 --- /dev/null +++ b/enumeration/remote/azurerm/repository/network.go @@ -0,0 +1,405 @@ +package repository + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/Azure/go-autorest/autorest/azure" +) + +type NetworkRepository interface { + ListAllVirtualNetworks() ([]*armnetwork.VirtualNetwork, error) + ListAllRouteTables() ([]*armnetwork.RouteTable, error) + ListAllSubnets(virtualNetwork *armnetwork.VirtualNetwork) ([]*armnetwork.Subnet, error) + ListAllFirewalls() ([]*armnetwork.AzureFirewall, error) + ListAllPublicIPAddresses() ([]*armnetwork.PublicIPAddress, error) + ListAllSecurityGroups() ([]*armnetwork.NetworkSecurityGroup, error) + ListAllLoadBalancers() ([]*armnetwork.LoadBalancer, error) + ListLoadBalancerRules(*armnetwork.LoadBalancer) ([]*armnetwork.LoadBalancingRule, error) +} + +type publicIPAddressesClient interface { + ListAll(options *armnetwork.PublicIPAddressesListAllOptions) publicIPAddressesListAllPager +} + +type publicIPAddressesListAllPager interface { + pager + PageResponse() armnetwork.PublicIPAddressesListAllResponse +} + +type publicIPAddressesClientImpl struct { + client *armnetwork.PublicIPAddressesClient +} + +func (p publicIPAddressesClientImpl) ListAll(options *armnetwork.PublicIPAddressesListAllOptions) publicIPAddressesListAllPager { + return p.client.ListAll(options) +} + +type firewallsListAllPager interface { + pager + PageResponse() armnetwork.AzureFirewallsListAllResponse +} + +type firewallsClient interface { + ListAll(options *armnetwork.AzureFirewallsListAllOptions) firewallsListAllPager +} + +type firewallsClientImpl struct { + client *armnetwork.AzureFirewallsClient +} + +func (s firewallsClientImpl) ListAll(options *armnetwork.AzureFirewallsListAllOptions) firewallsListAllPager { + return s.client.ListAll(options) +} + +type subnetsListPager interface { + pager + PageResponse() armnetwork.SubnetsListResponse +} + +type subnetsClient interface { + List(resourceGroupName, virtualNetworkName string, options *armnetwork.SubnetsListOptions) subnetsListPager +} + +type subnetsClientImpl struct { + client *armnetwork.SubnetsClient +} + +func (s subnetsClientImpl) List(resourceGroupName, virtualNetworkName string, options *armnetwork.SubnetsListOptions) subnetsListPager { + return s.client.List(resourceGroupName, virtualNetworkName, options) +} + +type virtualNetworksClient interface { + ListAll(options *armnetwork.VirtualNetworksListAllOptions) virtualNetworksListAllPager +} + +type virtualNetworksListAllPager interface { + pager + PageResponse() armnetwork.VirtualNetworksListAllResponse +} + +type virtualNetworksClientImpl struct { + client *armnetwork.VirtualNetworksClient +} + +func (c virtualNetworksClientImpl) ListAll(options *armnetwork.VirtualNetworksListAllOptions) virtualNetworksListAllPager { + return c.client.ListAll(options) +} + +type routeTablesClient interface { + ListAll(options *armnetwork.RouteTablesListAllOptions) routeTablesListAllPager +} + +type routeTablesListAllPager interface { + pager + PageResponse() armnetwork.RouteTablesListAllResponse +} + +type routeTablesClientImpl struct { + client *armnetwork.RouteTablesClient +} + +func (c routeTablesClientImpl) ListAll(options *armnetwork.RouteTablesListAllOptions) routeTablesListAllPager { + return c.client.ListAll(options) +} + +type networkSecurityGroupsListAllPager interface { + pager + PageResponse() armnetwork.NetworkSecurityGroupsListAllResponse +} + +type networkSecurityGroupsClient interface { + ListAll(options *armnetwork.NetworkSecurityGroupsListAllOptions) networkSecurityGroupsListAllPager +} + +type networkSecurityGroupsClientImpl struct { + client *armnetwork.NetworkSecurityGroupsClient +} + +func (s networkSecurityGroupsClientImpl) ListAll(options *armnetwork.NetworkSecurityGroupsListAllOptions) networkSecurityGroupsListAllPager { + return s.client.ListAll(options) +} + +type loadBalancersListAllPager interface { + pager + PageResponse() armnetwork.LoadBalancersListAllResponse +} + +type loadBalancersClient interface { + ListAll(options *armnetwork.LoadBalancersListAllOptions) loadBalancersListAllPager +} + +type loadBalancersClientImpl struct { + client *armnetwork.LoadBalancersClient +} + +func (s loadBalancersClientImpl) ListAll(options *armnetwork.LoadBalancersListAllOptions) loadBalancersListAllPager { + return s.client.ListAll(options) +} + +type loadBalancerRulesListAllPager interface { + pager + PageResponse() armnetwork.LoadBalancerLoadBalancingRulesListResponse +} + +type loadBalancerRulesClient interface { + List(string, string, *armnetwork.LoadBalancerLoadBalancingRulesListOptions) loadBalancerRulesListAllPager +} + +type loadBalancerRulesClientImpl struct { + client *armnetwork.LoadBalancerLoadBalancingRulesClient +} + +func (s loadBalancerRulesClientImpl) List(resourceGroupName string, loadBalancerName string, options *armnetwork.LoadBalancerLoadBalancingRulesListOptions) loadBalancerRulesListAllPager { + return s.client.List(resourceGroupName, loadBalancerName, options) +} + +type networkRepository struct { + virtualNetworksClient virtualNetworksClient + routeTableClient routeTablesClient + subnetsClient subnetsClient + firewallsClient firewallsClient + publicIPAddressesClient publicIPAddressesClient + networkSecurityGroupsClient networkSecurityGroupsClient + loadBalancersClient loadBalancersClient + loadBalancerRulesClient loadBalancerRulesClient + cache cache.Cache +} + +func NewNetworkRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *networkRepository { + return &networkRepository{ + &virtualNetworksClientImpl{client: armnetwork.NewVirtualNetworksClient(config.SubscriptionID, cred, options)}, + &routeTablesClientImpl{client: armnetwork.NewRouteTablesClient(config.SubscriptionID, cred, options)}, + &subnetsClientImpl{client: armnetwork.NewSubnetsClient(config.SubscriptionID, cred, options)}, + &firewallsClientImpl{client: armnetwork.NewAzureFirewallsClient(config.SubscriptionID, cred, options)}, + &publicIPAddressesClientImpl{client: armnetwork.NewPublicIPAddressesClient(config.SubscriptionID, cred, options)}, + &networkSecurityGroupsClientImpl{client: armnetwork.NewNetworkSecurityGroupsClient(config.SubscriptionID, cred, options)}, + &loadBalancersClientImpl{client: armnetwork.NewLoadBalancersClient(config.SubscriptionID, cred, options)}, + &loadBalancerRulesClientImpl{armnetwork.NewLoadBalancerLoadBalancingRulesClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *networkRepository) ListAllVirtualNetworks() ([]*armnetwork.VirtualNetwork, error) { + + cacheKey := "ListAllVirtualNetworks" + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.([]*armnetwork.VirtualNetwork), nil + } + + pager := s.virtualNetworksClient.ListAll(nil) + results := make([]*armnetwork.VirtualNetwork, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.VirtualNetworksListAllResult.VirtualNetworkListResult.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *networkRepository) ListAllRouteTables() ([]*armnetwork.RouteTable, error) { + cacheKey := "ListAllRouteTables" + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.([]*armnetwork.RouteTable), nil + } + + pager := s.routeTableClient.ListAll(nil) + results := make([]*armnetwork.RouteTable, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.RouteTablesListAllResult.RouteTableListResult.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *networkRepository) ListAllSubnets(virtualNetwork *armnetwork.VirtualNetwork) ([]*armnetwork.Subnet, error) { + + cacheKey := fmt.Sprintf("ListAllSubnets_%s", *virtualNetwork.ID) + + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armnetwork.Subnet), nil + } + + res, err := azure.ParseResourceID(*virtualNetwork.ID) + if err != nil { + return nil, err + } + + pager := s.subnetsClient.List(res.ResourceGroup, *virtualNetwork.Name, nil) + results := make([]*armnetwork.Subnet, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.SubnetsListResult.SubnetListResult.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *networkRepository) ListAllFirewalls() ([]*armnetwork.AzureFirewall, error) { + + cacheKey := "ListAllFirewalls" + + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armnetwork.AzureFirewall), nil + } + + pager := s.firewallsClient.ListAll(nil) + results := make([]*armnetwork.AzureFirewall, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.AzureFirewallsListAllResult.AzureFirewallListResult.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *networkRepository) ListAllPublicIPAddresses() ([]*armnetwork.PublicIPAddress, error) { + cacheKey := "ListAllPublicIPAddresses" + + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armnetwork.PublicIPAddress), nil + } + + pager := s.publicIPAddressesClient.ListAll(nil) + results := make([]*armnetwork.PublicIPAddress, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.PublicIPAddressesListAllResult.PublicIPAddressListResult.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *networkRepository) ListAllSecurityGroups() ([]*armnetwork.NetworkSecurityGroup, error) { + cacheKey := "networkListAllSecurityGroups" + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armnetwork.NetworkSecurityGroup), nil + } + + pager := s.networkSecurityGroupsClient.ListAll(nil) + results := make([]*armnetwork.NetworkSecurityGroup, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *networkRepository) ListAllLoadBalancers() ([]*armnetwork.LoadBalancer, error) { + cacheKey := "networkListAllLoadBalancers" + defer s.cache.Unlock(cacheKey) + if v := s.cache.GetAndLock(cacheKey); v != nil { + return v.([]*armnetwork.LoadBalancer), nil + } + + pager := s.loadBalancersClient.ListAll(nil) + results := make([]*armnetwork.LoadBalancer, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + return results, nil +} + +func (s *networkRepository) ListLoadBalancerRules(loadBalancer *armnetwork.LoadBalancer) ([]*armnetwork.LoadBalancingRule, error) { + cacheKey := fmt.Sprintf("networkListLoadBalancerRules_%s", *loadBalancer.ID) + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armnetwork.LoadBalancingRule), nil + } + + loadBalancerResource, err := azure.ParseResourceID(*loadBalancer.ID) + if err != nil { + return nil, err + } + + pager := s.loadBalancerRulesClient.List(loadBalancerResource.ResourceGroup, loadBalancerResource.ResourceName, &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}) + results := make([]*armnetwork.LoadBalancingRule, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + return results, nil +} diff --git a/enumeration/remote/azurerm/repository/network_test.go b/enumeration/remote/azurerm/repository/network_test.go new file mode 100644 index 00000000..450064d9 --- /dev/null +++ b/enumeration/remote/azurerm/repository/network_test.go @@ -0,0 +1,1172 @@ +package repository + +import ( + "context" + "fmt" + cache2 "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_ListAllVirtualNetwork_MultiplesResults(t *testing.T) { + + expected := []*armnetwork.VirtualNetwork{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network2"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network3"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network4"), + }, + }, + } + + fakeClient := &mockVirtualNetworkClient{} + + mockPager := &mockVirtualNetworksListAllPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{ + VirtualNetworksListAllResult: armnetwork.VirtualNetworksListAllResult{ + VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{ + Value: []*armnetwork.VirtualNetwork{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network2"), + }, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{ + VirtualNetworksListAllResult: armnetwork.VirtualNetworksListAllResult{ + VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{ + Value: []*armnetwork.VirtualNetwork{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network3"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network4"), + }, + }, + }, + }, + }, + }).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "ListAllVirtualNetworks").Return(nil).Times(1) + c.On("Unlock", "ListAllVirtualNetworks").Times(1) + c.On("Put", "ListAllVirtualNetworks", expected).Return(true).Times(1) + s := &networkRepository{ + virtualNetworksClient: fakeClient, + cache: c, + } + got, err := s.ListAllVirtualNetworks() + if err != nil { + t.Errorf("ListAllVirtualNetworks() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllVirtualNetworks() got = %v, want %v", got, expected) + } +} + +func Test_ListAllVirtualNetwork_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armnetwork.VirtualNetwork{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network3"), + }, + }, + } + + fakeClient := &mockVirtualNetworkClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "ListAllVirtualNetworks").Return(expected).Times(1) + c.On("Unlock", "ListAllVirtualNetworks").Times(1) + s := &networkRepository{ + virtualNetworksClient: fakeClient, + cache: c, + } + got, err := s.ListAllVirtualNetworks() + if err != nil { + t.Errorf("ListAllVirtualNetworks() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllVirtualNetworks() got = %v, want %v", got, expected) + } +} + +func Test_ListAllVirtualNetwork_Error_OnPageResponse(t *testing.T) { + + fakeClient := &mockVirtualNetworkClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockVirtualNetworksListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{}).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + virtualNetworksClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllVirtualNetworks() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllVirtualNetwork_Error(t *testing.T) { + + fakeClient := &mockVirtualNetworkClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockVirtualNetworksListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + virtualNetworksClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllVirtualNetworks() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllRouteTables_MultiplesResults(t *testing.T) { + + expected := []*armnetwork.RouteTable{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("table1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("table2"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("table3"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("table4"), + }, + }, + } + + fakeClient := &mockRouteTablesClient{} + + mockPager := &mockRouteTablesListAllPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{ + RouteTablesListAllResult: armnetwork.RouteTablesListAllResult{ + RouteTableListResult: armnetwork.RouteTableListResult{ + Value: expected[:2], + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{ + RouteTablesListAllResult: armnetwork.RouteTablesListAllResult{ + RouteTableListResult: armnetwork.RouteTableListResult{ + Value: expected[2:], + }, + }, + }).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "ListAllRouteTables").Return(nil).Times(1) + c.On("Unlock", "ListAllRouteTables").Times(1) + c.On("Put", "ListAllRouteTables", expected).Return(true).Times(1) + s := &networkRepository{ + routeTableClient: fakeClient, + cache: c, + } + got, err := s.ListAllRouteTables() + if err != nil { + t.Errorf("ListAllRouteTables() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllRouteTables() got = %v, want %v", got, expected) + } +} + +func Test_ListAllRouteTables_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armnetwork.RouteTable{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("table1"), + }, + }, + } + + fakeClient := &mockRouteTablesClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "ListAllRouteTables").Return(expected).Times(1) + c.On("Unlock", "ListAllRouteTables").Times(1) + s := &networkRepository{ + routeTableClient: fakeClient, + cache: c, + } + got, err := s.ListAllRouteTables() + if err != nil { + t.Errorf("ListAllRouteTables() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllRouteTables() got = %v, want %v", got, expected) + } +} + +func Test_ListAllRouteTables_Error_OnPageResponse(t *testing.T) { + + fakeClient := &mockRouteTablesClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockRouteTablesListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{}).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + routeTableClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllRouteTables() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllRouteTables_Error(t *testing.T) { + + fakeClient := &mockRouteTablesClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockRouteTablesListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + routeTableClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllRouteTables() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllSubnets_MultiplesResults(t *testing.T) { + + network := &armnetwork.VirtualNetwork{ + Resource: armnetwork.Resource{ + Name: to.StringPtr("network1"), + ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"), + }, + } + + expected := []*armnetwork.Subnet{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet1"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet2"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet3"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet4"), + }, + }, + } + + fakeClient := &mockSubnetsClient{} + + mockPager := &mockSubnetsListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{ + SubnetsListResult: armnetwork.SubnetsListResult{ + SubnetListResult: armnetwork.SubnetListResult{ + Value: []*armnetwork.Subnet{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet1"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet2"), + }, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{ + SubnetsListResult: armnetwork.SubnetsListResult{ + SubnetListResult: armnetwork.SubnetListResult{ + Value: []*armnetwork.Subnet{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet3"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet4"), + }, + }, + }, + }, + }, + }).Times(1) + + fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + cacheKey := fmt.Sprintf("ListAllSubnets_%s", *network.ID) + c.On("Get", cacheKey).Return(nil).Times(1) + c.On("Put", cacheKey, expected).Return(true).Times(1) + s := &networkRepository{ + subnetsClient: fakeClient, + cache: c, + } + got, err := s.ListAllSubnets(network) + if err != nil { + t.Errorf("ListAllSubnets() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllSubnets() got = %v, want %v", got, expected) + } +} + +func Test_ListAllSubnets_MultiplesResults_WithCache(t *testing.T) { + + network := &armnetwork.VirtualNetwork{ + Resource: armnetwork.Resource{ + ID: to.StringPtr("networkID"), + }, + } + + expected := []*armnetwork.Subnet{ + { + Name: to.StringPtr("network1"), + }, + } + fakeClient := &mockSubnetsClient{} + + c := &cache2.MockCache{} + c.On("Get", "ListAllSubnets_networkID").Return(expected).Times(1) + s := &networkRepository{ + subnetsClient: fakeClient, + cache: c, + } + got, err := s.ListAllSubnets(network) + if err != nil { + t.Errorf("ListAllSubnets() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllSubnets() got = %v, want %v", got, expected) + } +} + +func Test_ListAllSubnets_Error_OnPageResponse(t *testing.T) { + + network := &armnetwork.VirtualNetwork{ + Resource: armnetwork.Resource{ + Name: to.StringPtr("network1"), + ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"), + }, + } + + fakeClient := &mockSubnetsClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockSubnetsListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{}).Times(1) + + fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager) + + s := &networkRepository{ + subnetsClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllSubnets(network) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllSubnets_Error(t *testing.T) { + + network := &armnetwork.VirtualNetwork{ + Resource: armnetwork.Resource{ + Name: to.StringPtr("network1"), + ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"), + }, + } + + fakeClient := &mockSubnetsClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockSubnetsListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + + fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager) + + s := &networkRepository{ + subnetsClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllSubnets(network) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllSubnets_ErrorOnInvalidNetworkID(t *testing.T) { + + network := &armnetwork.VirtualNetwork{ + Resource: armnetwork.Resource{ + Name: to.StringPtr("network1"), + ID: to.StringPtr("foobar"), + }, + } + + fakeClient := &mockSubnetsClient{} + + expectedErr := errors.New("parsing failed for foobar. Invalid resource Id format") + + s := &networkRepository{ + subnetsClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllSubnets(network) + + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr.Error(), err.Error()) + assert.Nil(t, got) +} + +func Test_ListAllFirewalls_MultiplesResults(t *testing.T) { + + expected := []*armnetwork.AzureFirewall{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall2"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall3"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall4"), + }, + }, + } + + fakeClient := &mockFirewallsClient{} + + mockPager := &mockFirewallsListAllPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{ + AzureFirewallsListAllResult: armnetwork.AzureFirewallsListAllResult{ + AzureFirewallListResult: armnetwork.AzureFirewallListResult{ + Value: expected[:2], + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{ + AzureFirewallsListAllResult: armnetwork.AzureFirewallsListAllResult{ + AzureFirewallListResult: armnetwork.AzureFirewallListResult{ + Value: expected[2:], + }, + }, + }).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + c.On("Get", "ListAllFirewalls").Return(nil).Times(1) + c.On("Put", "ListAllFirewalls", expected).Return(true).Times(1) + s := &networkRepository{ + firewallsClient: fakeClient, + cache: c, + } + got, err := s.ListAllFirewalls() + if err != nil { + t.Errorf("ListAllFirewalls() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllFirewalls() got = %v, want %v", got, expected) + } +} + +func Test_ListAllFirewalls_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armnetwork.AzureFirewall{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall1"), + }, + }, + } + + fakeClient := &mockFirewallsClient{} + + c := &cache2.MockCache{} + c.On("Get", "ListAllFirewalls").Return(expected).Times(1) + s := &networkRepository{ + firewallsClient: fakeClient, + cache: c, + } + got, err := s.ListAllFirewalls() + if err != nil { + t.Errorf("ListAllFirewalls() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllFirewalls() got = %v, want %v", got, expected) + } +} + +func Test_ListAllFirewalls_Error_OnPageResponse(t *testing.T) { + + fakeClient := &mockFirewallsClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockFirewallsListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{}).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + firewallsClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllFirewalls() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllFirewalls_Error(t *testing.T) { + + fakeClient := &mockFirewallsClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockFirewallsListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + firewallsClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllFirewalls() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllPublicIPAddresses_MultiplesResults(t *testing.T) { + + expected := []*armnetwork.PublicIPAddress{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip2"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip3"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip4"), + }, + }, + } + + fakeClient := &mockPublicIPAddressesClient{} + + mockPager := &mockPublicIPAddressesListAllPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{ + PublicIPAddressesListAllResult: armnetwork.PublicIPAddressesListAllResult{ + PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{ + Value: expected[:2], + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{ + PublicIPAddressesListAllResult: armnetwork.PublicIPAddressesListAllResult{ + PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{ + Value: expected[2:], + }, + }, + }).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + c.On("Get", "ListAllPublicIPAddresses").Return(nil).Times(1) + c.On("Put", "ListAllPublicIPAddresses", expected).Return(true).Times(1) + s := &networkRepository{ + publicIPAddressesClient: fakeClient, + cache: c, + } + got, err := s.ListAllPublicIPAddresses() + if err != nil { + t.Errorf("ListAllPublicIPAddresses() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllPublicIPAddresses() got = %v, want %v", got, expected) + } +} + +func Test_ListAllPublicIPAddresses_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armnetwork.PublicIPAddress{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip1"), + }, + }, + } + + fakeClient := &mockPublicIPAddressesClient{} + + c := &cache2.MockCache{} + c.On("Get", "ListAllPublicIPAddresses").Return(expected).Times(1) + s := &networkRepository{ + publicIPAddressesClient: fakeClient, + cache: c, + } + got, err := s.ListAllPublicIPAddresses() + if err != nil { + t.Errorf("ListAllPublicIPAddresses() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllPublicIPAddresses() got = %v, want %v", got, expected) + } +} + +func Test_ListAllPublicIPAddresses_Error_OnPageResponse(t *testing.T) { + + fakeClient := &mockPublicIPAddressesClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPublicIPAddressesListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{}).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + publicIPAddressesClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllPublicIPAddresses() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllPublicIPAddresses_Error(t *testing.T) { + + fakeClient := &mockPublicIPAddressesClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPublicIPAddressesListAllPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + + fakeClient.On("ListAll", mock.Anything).Return(mockPager) + + s := &networkRepository{ + publicIPAddressesClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllPublicIPAddresses() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_Network_ListAllSecurityGroups(t *testing.T) { + expectedResults := []*armnetwork.NetworkSecurityGroup{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("sgroup-1"), + Name: to.StringPtr("sgroup-1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("sgroup-2"), + Name: to.StringPtr("sgroup-2"), + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockNetworkSecurityGroupsListAllPager, *cache2.MockCache) + expected []*armnetwork.NetworkSecurityGroup + wantErr string + }{ + { + name: "should return security groups", + mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache2.MockCache) { + pager.On("NextPage", context.Background()).Return(true).Times(1) + pager.On("NextPage", context.Background()).Return(false).Times(1) + pager.On("PageResponse").Return(armnetwork.NetworkSecurityGroupsListAllResponse{ + NetworkSecurityGroupsListAllResult: armnetwork.NetworkSecurityGroupsListAllResult{ + NetworkSecurityGroupListResult: armnetwork.NetworkSecurityGroupListResult{ + Value: expectedResults, + }, + }, + }).Times(1) + pager.On("Err").Return(nil).Times(2) + + mockCache.On("Get", "networkListAllSecurityGroups").Return(nil).Times(1) + mockCache.On("Put", "networkListAllSecurityGroups", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return security groups", + mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache2.MockCache) { + mockCache.On("Get", "networkListAllSecurityGroups").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache2.MockCache) { + pager.On("NextPage", context.Background()).Return(true).Times(1) + pager.On("NextPage", context.Background()).Return(false).Times(1) + pager.On("PageResponse").Return(armnetwork.NetworkSecurityGroupsListAllResponse{}).Times(1) + pager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "networkListAllSecurityGroups").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakePager := &mockNetworkSecurityGroupsListAllPager{} + fakeClient := &mockNetworkSecurityGroupsClient{} + mockCache := &cache2.MockCache{} + + fakeClient.On("ListAll", (*armnetwork.NetworkSecurityGroupsListAllOptions)(nil)).Return(fakePager).Maybe() + + tt.mocks(fakePager, mockCache) + + s := &networkRepository{ + networkSecurityGroupsClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllSecurityGroups() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllSecurityGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} + +func Test_Network_ListAllLoadBalancers(t *testing.T) { + expectedResults := []*armnetwork.LoadBalancer{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("lb-1"), + Name: to.StringPtr("lb-1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("lb-2"), + Name: to.StringPtr("lb-2"), + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockLoadBalancersListAllPager, *cache2.MockCache) + expected []*armnetwork.LoadBalancer + wantErr string + }{ + { + name: "should return load balancers", + mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache2.MockCache) { + pager.On("NextPage", context.Background()).Return(true).Times(1) + pager.On("NextPage", context.Background()).Return(false).Times(1) + pager.On("PageResponse").Return(armnetwork.LoadBalancersListAllResponse{ + LoadBalancersListAllResult: armnetwork.LoadBalancersListAllResult{ + LoadBalancerListResult: armnetwork.LoadBalancerListResult{ + Value: expectedResults, + }, + }, + }).Times(1) + pager.On("Err").Return(nil).Times(2) + + mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(nil).Times(1) + mockCache.On("Put", "networkListAllLoadBalancers", expectedResults).Return(false).Times(1) + mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return load balancers", + mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache2.MockCache) { + mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(expectedResults).Times(1) + mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache2.MockCache) { + pager.On("NextPage", context.Background()).Return(true).Times(1) + pager.On("NextPage", context.Background()).Return(false).Times(1) + pager.On("PageResponse").Return(armnetwork.LoadBalancersListAllResponse{}).Times(1) + pager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(nil).Times(1) + mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakePager := &mockLoadBalancersListAllPager{} + fakeClient := &mockLoadBalancersClient{} + mockCache := &cache2.MockCache{} + + fakeClient.On("ListAll", (*armnetwork.LoadBalancersListAllOptions)(nil)).Return(fakePager).Maybe() + + tt.mocks(fakePager, mockCache) + + s := &networkRepository{ + loadBalancersClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllLoadBalancers() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllLoadBalancers() got = %v, want %v", got, tt.expected) + } + }) + } +} + +func Test_Network_ListLoadBalancerRules(t *testing.T) { + expectedResults := []*armnetwork.LoadBalancingRule{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("lbrule-1"), + }, + Name: to.StringPtr("lbrule-1"), + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("lbrule-1"), + }, + Name: to.StringPtr("lbrule-1"), + }, + } + + testcases := []struct { + name string + loadBalancer *armnetwork.LoadBalancer + mocks func(*mockLoadBalancerRulesClient, *mockLoadBalancerRulesListAllPager, *cache2.MockCache) + expected []*armnetwork.LoadBalancingRule + wantErr string + }{ + { + name: "should return load balancer rules", + loadBalancer: &armnetwork.LoadBalancer{ + Resource: armnetwork.Resource{ID: to.StringPtr("/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress")}, + }, + mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache2.MockCache) { + client.On("List", "driftctl", "PublicIPAddress", &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}).Return(pager) + + pager.On("NextPage", context.Background()).Return(true).Times(1) + pager.On("NextPage", context.Background()).Return(false).Times(1) + pager.On("PageResponse").Return(armnetwork.LoadBalancerLoadBalancingRulesListResponse{ + LoadBalancerLoadBalancingRulesListResult: armnetwork.LoadBalancerLoadBalancingRulesListResult{ + LoadBalancerLoadBalancingRuleListResult: armnetwork.LoadBalancerLoadBalancingRuleListResult{ + Value: expectedResults, + }, + }, + }).Times(1) + pager.On("Err").Return(nil).Times(2) + + mockCache.On("Get", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress").Return(nil).Times(1) + mockCache.On("Put", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return load balancers", + loadBalancer: &armnetwork.LoadBalancer{ + Resource: armnetwork.Resource{ID: to.StringPtr("lb-1")}, + }, + mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache2.MockCache) { + mockCache.On("Get", "networkListLoadBalancerRules_lb-1").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + loadBalancer: &armnetwork.LoadBalancer{ + Resource: armnetwork.Resource{ID: to.StringPtr("/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress")}, + }, + mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache2.MockCache) { + client.On("List", "driftctl", "PublicIPAddress", &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}).Return(pager) + + pager.On("NextPage", context.Background()).Return(true).Times(1) + pager.On("NextPage", context.Background()).Return(false).Times(1) + pager.On("PageResponse").Return(armnetwork.LoadBalancerLoadBalancingRulesListResponse{}).Times(1) + pager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakePager := &mockLoadBalancerRulesListAllPager{} + fakeClient := &mockLoadBalancerRulesClient{} + mockCache := &cache2.MockCache{} + + tt.mocks(fakeClient, fakePager, mockCache) + + s := &networkRepository{ + loadBalancerRulesClient: fakeClient, + cache: mockCache, + } + got, err := s.ListLoadBalancerRules(tt.loadBalancer) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllLoadBalancers() got = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/pkg/remote/azurerm/repository/pager.go b/enumeration/remote/azurerm/repository/pager.go similarity index 100% rename from pkg/remote/azurerm/repository/pager.go rename to enumeration/remote/azurerm/repository/pager.go diff --git a/enumeration/remote/azurerm/repository/postgresql.go b/enumeration/remote/azurerm/repository/postgresql.go new file mode 100644 index 00000000..f9024eaf --- /dev/null +++ b/enumeration/remote/azurerm/repository/postgresql.go @@ -0,0 +1,93 @@ +package repository + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" + "github.com/Azure/go-autorest/autorest/azure" +) + +type PostgresqlRespository interface { + ListAllServers() ([]*armpostgresql.Server, error) + ListAllDatabasesByServer(server *armpostgresql.Server) ([]*armpostgresql.Database, error) +} + +type postgresqlServersClientImpl struct { + client *armpostgresql.ServersClient +} + +type postgresqlServersClient interface { + List(context.Context, *armpostgresql.ServersListOptions) (armpostgresql.ServersListResponse, error) +} + +func (c postgresqlServersClientImpl) List(ctx context.Context, options *armpostgresql.ServersListOptions) (armpostgresql.ServersListResponse, error) { + return c.client.List(ctx, options) +} + +type postgresqlDatabaseClientImpl struct { + client *armpostgresql.DatabasesClient +} + +type postgresqlDatabaseClient interface { + ListByServer(context.Context, string, string, *armpostgresql.DatabasesListByServerOptions) (armpostgresql.DatabasesListByServerResponse, error) +} + +func (c postgresqlDatabaseClientImpl) ListByServer(ctx context.Context, resGroup string, serverName string, options *armpostgresql.DatabasesListByServerOptions) (armpostgresql.DatabasesListByServerResponse, error) { + return c.client.ListByServer(ctx, resGroup, serverName, options) +} + +type postgresqlRepository struct { + serversClient postgresqlServersClient + databaseClient postgresqlDatabaseClient + cache cache.Cache +} + +func NewPostgresqlRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *postgresqlRepository { + return &postgresqlRepository{ + postgresqlServersClientImpl{client: armpostgresql.NewServersClient(config.SubscriptionID, cred, options)}, + postgresqlDatabaseClientImpl{client: armpostgresql.NewDatabasesClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *postgresqlRepository) ListAllServers() ([]*armpostgresql.Server, error) { + cacheKey := "postgresqlListAllServers" + + defer s.cache.Unlock(cacheKey) + if v := s.cache.GetAndLock(cacheKey); v != nil { + return v.([]*armpostgresql.Server), nil + } + + res, err := s.serversClient.List(context.Background(), nil) + if err != nil { + return nil, err + } + + s.cache.Put(cacheKey, res.Value) + return res.Value, nil +} + +func (s *postgresqlRepository) ListAllDatabasesByServer(server *armpostgresql.Server) ([]*armpostgresql.Database, error) { + res, err := azure.ParseResourceID(*server.ID) + if err != nil { + return nil, err + } + + cacheKey := fmt.Sprintf("postgresqlListAllDatabases_%s_%s", res.ResourceGroup, *server.Name) + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armpostgresql.Database), nil + } + + result, err := s.databaseClient.ListByServer(context.Background(), res.ResourceGroup, *server.Name, nil) + if err != nil { + return nil, err + } + + s.cache.Put(cacheKey, result.Value) + return result.Value, nil +} diff --git a/enumeration/remote/azurerm/repository/postgresql_test.go b/enumeration/remote/azurerm/repository/postgresql_test.go new file mode 100644 index 00000000..2d6ac170 --- /dev/null +++ b/enumeration/remote/azurerm/repository/postgresql_test.go @@ -0,0 +1,196 @@ +package repository + +import ( + "context" + "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Postgresql_ListAllServers(t *testing.T) { + expectedResults := []*armpostgresql.Server{ + { + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("postgresql-server-1"), + }, + }, + }, + { + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("postgresql-server-2"), + }, + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockPostgresqlServersClient, *cache.MockCache) + expected []*armpostgresql.Server + wantErr string + }{ + { + name: "should return postgres servers", + mocks: func(client *mockPostgresqlServersClient, mockCache *cache.MockCache) { + client.On("List", context.Background(), mock.Anything).Return(armpostgresql.ServersListResponse{ + ServersListResult: armpostgresql.ServersListResult{ + ServerListResult: armpostgresql.ServerListResult{ + Value: expectedResults, + }, + }, + }, nil).Times(1) + + mockCache.On("GetAndLock", "postgresqlListAllServers").Return(nil).Times(1) + mockCache.On("Unlock", "postgresqlListAllServers").Return().Times(1) + mockCache.On("Put", "postgresqlListAllServers", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return postgres servers", + mocks: func(client *mockPostgresqlServersClient, mockCache *cache.MockCache) { + mockCache.On("GetAndLock", "postgresqlListAllServers").Return(expectedResults).Times(1) + mockCache.On("Unlock", "postgresqlListAllServers").Return().Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(client *mockPostgresqlServersClient, mockCache *cache.MockCache) { + client.On("List", context.Background(), mock.Anything).Return(armpostgresql.ServersListResponse{}, errors.New("remote error")).Times(1) + + mockCache.On("GetAndLock", "postgresqlListAllServers").Return(nil).Times(1) + mockCache.On("Unlock", "postgresqlListAllServers").Return().Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakeClient := &mockPostgresqlServersClient{} + mockCache := &cache.MockCache{} + + tt.mocks(fakeClient, mockCache) + + s := &postgresqlRepository{ + serversClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllServers() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} + +func Test_Postgresql_ListAllDatabases(t *testing.T) { + expectedResults := []*armpostgresql.Database{ + { + ProxyResource: armpostgresql.ProxyResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/res-group/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-1/databases/postgresql-db-1"), + }, + }, + }, + { + ProxyResource: armpostgresql.ProxyResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/res-group/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-1/databases/postgresql-db-2"), + }, + }, + }, + } + + testcases := []struct { + name string + mocks func(*mockPostgresqlDatabaseClient, *cache.MockCache) + expected []*armpostgresql.Database + wantErr string + }{ + { + name: "should return postgres servers", + mocks: func(client *mockPostgresqlDatabaseClient, mockCache *cache.MockCache) { + client.On("ListByServer", context.Background(), "res-group", "postgresql-server-1", (*armpostgresql.DatabasesListByServerOptions)(nil)).Return(armpostgresql.DatabasesListByServerResponse{ + DatabasesListByServerResult: armpostgresql.DatabasesListByServerResult{ + DatabaseListResult: armpostgresql.DatabaseListResult{ + Value: expectedResults, + }, + }, + }, nil).Times(1) + + mockCache.On("Get", "postgresqlListAllDatabases_res-group_postgresql-server-1").Return(nil).Times(1) + mockCache.On("Put", "postgresqlListAllDatabases_res-group_postgresql-server-1", expectedResults).Return(false).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return postgres servers", + mocks: func(client *mockPostgresqlDatabaseClient, mockCache *cache.MockCache) { + mockCache.On("Get", "postgresqlListAllDatabases_res-group_postgresql-server-1").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(client *mockPostgresqlDatabaseClient, mockCache *cache.MockCache) { + mockCache.On("Get", "postgresqlListAllDatabases_res-group_postgresql-server-1").Return(nil).Times(1) + + client.On("ListByServer", context.Background(), "res-group", "postgresql-server-1", (*armpostgresql.DatabasesListByServerOptions)(nil)).Return(armpostgresql.DatabasesListByServerResponse{}, errors.New("remote error")).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakeClient := &mockPostgresqlDatabaseClient{} + mockCache := &cache.MockCache{} + + tt.mocks(fakeClient, mockCache) + + s := &postgresqlRepository{ + databaseClient: fakeClient, + cache: mockCache, + } + got, err := s.ListAllDatabasesByServer(&armpostgresql.Server{ + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/res-group/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-1"), + Name: to.StringPtr("postgresql-server-1"), + }, + }, + }) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/enumeration/remote/azurerm/repository/privatedns.go b/enumeration/remote/azurerm/repository/privatedns.go new file mode 100644 index 00000000..59a9a333 --- /dev/null +++ b/enumeration/remote/azurerm/repository/privatedns.go @@ -0,0 +1,243 @@ +package repository + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" + "github.com/Azure/go-autorest/autorest/azure" +) + +type PrivateDNSRepository interface { + ListAllPrivateZones() ([]*armprivatedns.PrivateZone, error) + ListAllARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) + ListAllAAAARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) + ListAllCNAMERecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) + ListAllPTRRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) + ListAllMXRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) + ListAllSRVRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) + ListAllTXTRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) +} + +type privateDNSZoneListPager interface { + pager + PageResponse() armprivatedns.PrivateZonesListResponse +} + +type privateDNSRecordSetListPager interface { + pager + PageResponse() armprivatedns.RecordSetsListResponse +} + +type privateRecordSetClient interface { + List(resourceGroupName string, privateZoneName string, options *armprivatedns.RecordSetsListOptions) privateDNSRecordSetListPager +} + +type privateRecordSetClientImpl struct { + client *armprivatedns.RecordSetsClient +} + +func (c *privateRecordSetClientImpl) List(resourceGroupName string, privateZoneName string, options *armprivatedns.RecordSetsListOptions) privateDNSRecordSetListPager { + return c.client.List(resourceGroupName, privateZoneName, options) +} + +type privateZonesClient interface { + List(options *armprivatedns.PrivateZonesListOptions) privateDNSZoneListPager +} + +type privateZonesClientImpl struct { + client *armprivatedns.PrivateZonesClient +} + +func (c *privateZonesClientImpl) List(options *armprivatedns.PrivateZonesListOptions) privateDNSZoneListPager { + return c.client.List(options) +} + +type privateDNSRepository struct { + zoneClient privateZonesClient + recordClient privateRecordSetClient + cache cache.Cache +} + +func NewPrivateDNSRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *privateDNSRepository { + return &privateDNSRepository{ + &privateZonesClientImpl{armprivatedns.NewPrivateZonesClient(config.SubscriptionID, cred, options)}, + &privateRecordSetClientImpl{armprivatedns.NewRecordSetsClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *privateDNSRepository) listAllRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + cacheKey := fmt.Sprintf("privateDNSlistAllRecords-%s", *zone.ID) + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.([]*armprivatedns.RecordSet), nil + } + + res, err := azure.ParseResourceID(*zone.ID) + if err != nil { + return nil, err + } + + pager := s.recordClient.List(res.ResourceGroup, *zone.Name, nil) + results := make([]*armprivatedns.RecordSet, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *privateDNSRepository) ListAllARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.ARecords == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllAAAARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.AaaaRecords == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllPTRRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.PtrRecords == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllCNAMERecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.CnameRecord == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllMXRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.MxRecords == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllSRVRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.SrvRecords == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllTXTRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { + records, err := s.listAllRecords(zone) + if err != nil { + return nil, err + } + results := make([]*armprivatedns.RecordSet, 0) + for _, record := range records { + if record.Properties.TxtRecords == nil { + continue + } + results = append(results, record) + + } + return results, nil +} + +func (s *privateDNSRepository) ListAllPrivateZones() ([]*armprivatedns.PrivateZone, error) { + cacheKey := "privateDNSListAllPrivateZones" + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.([]*armprivatedns.PrivateZone), nil + } + + pager := s.zoneClient.List(nil) + results := make([]*armprivatedns.PrivateZone, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} diff --git a/enumeration/remote/azurerm/repository/privatedns_test.go b/enumeration/remote/azurerm/repository/privatedns_test.go new file mode 100644 index 00000000..fddcb0c4 --- /dev/null +++ b/enumeration/remote/azurerm/repository/privatedns_test.go @@ -0,0 +1,1639 @@ +package repository + +import ( + cache2 "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// region PrivateZone +func Test_ListAllPrivateZones_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone1"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone2"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone3"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone4"), + }, + }, + }, + } + + fakeClient := &mockPrivateZonesClient{} + + mockPager := &mockPrivateDNSZoneListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.PrivateZonesListResponse{ + PrivateZonesListResult: armprivatedns.PrivateZonesListResult{ + PrivateZoneListResult: armprivatedns.PrivateZoneListResult{ + Value: []*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone1"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone2"), + }, + }, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.PrivateZonesListResponse{ + PrivateZonesListResult: armprivatedns.PrivateZonesListResult{ + PrivateZoneListResult: armprivatedns.PrivateZoneListResult{ + Value: []*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone3"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone4"), + }, + }, + }, + }, + }, + }, + }).Times(1) + + fakeClient.On("List", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSListAllPrivateZones").Return(nil).Times(1) + c.On("Unlock", "privateDNSListAllPrivateZones").Times(1) + c.On("Put", "privateDNSListAllPrivateZones", expected).Return(true).Times(1) + s := &privateDNSRepository{ + zoneClient: fakeClient, + cache: c, + } + got, err := s.ListAllPrivateZones() + if err != nil { + t.Errorf("ListAllPrivateZones() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllPrivateZones() got = %v, want %v", got, expected) + } +} + +func Test_ListAllPrivateZones_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("zone1"), + }, + }, + }, + } + + fakeClient := &mockPrivateZonesClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSListAllPrivateZones").Return(expected).Times(1) + c.On("Unlock", "privateDNSListAllPrivateZones").Times(1) + + s := &privateDNSRepository{ + zoneClient: fakeClient, + cache: c, + } + got, err := s.ListAllPrivateZones() + if err != nil { + t.Errorf("ListAllPrivateZones() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllPrivateZones() got = %v, want %v", got, expected) + } +} + +func Test_ListAllPrivateZones_Error(t *testing.T) { + + fakeClient := &mockPrivateZonesClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSZoneListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.PrivateZonesListResponse{}).Times(1) + + fakeClient.On("List", mock.Anything).Return(mockPager) + + s := &privateDNSRepository{ + zoneClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllPrivateZones() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region ARecord +func Test_ListAllARecords_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("ip")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("ip")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("ip")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("ip")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllARecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllARecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllARecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllARecords_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("ip")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllARecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllARecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllARecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllARecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllARecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region AAAAAAARecord +func Test_ListAllAAAARecords_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("ip")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("ip")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("ip")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("ip")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllAAAARecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllAAAARecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllAAAARecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllAAAARecords_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("ip")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllAAAARecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllAAAARecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllAAAARecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllAAAARecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllAAAARecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region CNAMERecord +func Test_ListAllCNAMERecords_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + CnameRecord: &armprivatedns.CnameRecord{ + Cname: to.StringPtr("cname"), + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + CnameRecord: &armprivatedns.CnameRecord{ + Cname: to.StringPtr("cname"), + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + CnameRecord: &armprivatedns.CnameRecord{ + Cname: to.StringPtr("cname"), + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + CnameRecord: &armprivatedns.CnameRecord{ + Cname: to.StringPtr("cname"), + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllCNAMERecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllCNAMERecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllCNAMERecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllCNAMERecords_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + CnameRecord: &armprivatedns.CnameRecord{ + Cname: to.StringPtr("cname"), + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllCNAMERecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllCNAMERecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllCNAMERecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllCNAMERecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllCNAMERecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region PTRRecord +func Test_ListAllPTRRecords_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptrdname")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptrdname")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptrdname")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptrdname")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllPTRRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllPTRRecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllPTRRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllPTRRecords_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptrdname")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllPTRRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllPTRRecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllPTRRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllPTRRecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllPTRRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region MXRecord +func Test_ListAllMXRecords_MultiplesResults(t *testing.T) { + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllMXRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllMXRecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllMXRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllMXRecords_MultiplesResults_WithCache(t *testing.T) { + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + + got, err := s.ListAllMXRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + + t.Errorf("ListAllMXRecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllMXRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllMXRecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + + got, err := s.ListAllMXRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region SRVRecord +func Test_ListAllSRVRecords_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + SrvRecords: []*armprivatedns.SrvRecord{ + {Target: to.StringPtr("targetname")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + SrvRecords: []*armprivatedns.SrvRecord{ + {Target: to.StringPtr("targetname")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + SrvRecords: []*armprivatedns.SrvRecord{ + {Target: to.StringPtr("targetname")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + SrvRecords: []*armprivatedns.SrvRecord{ + {Target: to.StringPtr("targetname")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllSRVRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllSRVRecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllSRVRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllSRVRecords_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + SrvRecords: []*armprivatedns.SrvRecord{ + {Target: to.StringPtr("targetname")}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllSRVRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllSRVRecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllSRVRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllSRVRecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllSRVRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion + +// region TXTRecord +func Test_ListAllTXTRecords_MultiplesResults(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + TxtRecords: []*armprivatedns.TxtRecord{ + {Value: []*string{to.StringPtr("value")}}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + TxtRecords: []*armprivatedns.TxtRecord{ + {Value: []*string{to.StringPtr("value")}}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + TxtRecords: []*armprivatedns.TxtRecord{ + {Value: []*string{to.StringPtr("value")}}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record2"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ + RecordSetsListResult: armprivatedns.RecordSetsListResult{ + RecordSetListResult: armprivatedns.RecordSetListResult{ + Value: []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record3"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + TxtRecords: []*armprivatedns.TxtRecord{ + {Value: []*string{to.StringPtr("value")}}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record4"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{}, + }, + }, + }, + }, + }).Times(1) + + fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) + c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllTXTRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllTXTRecords() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllTXTRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllTXTRecords_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("record1"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + TxtRecords: []*armprivatedns.TxtRecord{ + {Value: []*string{to.StringPtr("value")}}, + }, + }, + }, + } + + fakeRecordSetClient := &mockPrivateRecordSetClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) + c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) + s := &privateDNSRepository{ + recordClient: fakeRecordSetClient, + cache: c, + } + got, err := s.ListAllTXTRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + if err != nil { + t.Errorf("ListAllTXTRecords() error = %v", err) + return + } + + fakeRecordSetClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllTXTRecords() got = %v, want %v", got, expected) + } +} + +func Test_ListAllTXTRecords_Error(t *testing.T) { + + fakeClient := &mockPrivateRecordSetClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockPrivateDNSRecordSetListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) + + fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) + + s := &privateDNSRepository{ + recordClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllTXTRecords(&armprivatedns.PrivateZone{ + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), + Name: to.StringPtr("zone"), + }, + }, + }) + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +// endregion diff --git a/enumeration/remote/azurerm/repository/resources.go b/enumeration/remote/azurerm/repository/resources.go new file mode 100644 index 00000000..768e91e0 --- /dev/null +++ b/enumeration/remote/azurerm/repository/resources.go @@ -0,0 +1,68 @@ +package repository + +import ( + "context" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" +) + +type ResourcesRepository interface { + ListAllResourceGroups() ([]*armresources.ResourceGroup, error) +} + +type resourcesListPager interface { + pager + PageResponse() armresources.ResourceGroupsListResponse +} + +type resourcesClient interface { + List(options *armresources.ResourceGroupsListOptions) resourcesListPager +} + +type resourcesClientImpl struct { + client *armresources.ResourceGroupsClient +} + +func (c resourcesClientImpl) List(options *armresources.ResourceGroupsListOptions) resourcesListPager { + return c.client.List(options) +} + +type resourcesRepository struct { + client resourcesClient + cache cache.Cache +} + +func NewResourcesRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *resourcesRepository { + return &resourcesRepository{ + &resourcesClientImpl{armresources.NewResourceGroupsClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *resourcesRepository) ListAllResourceGroups() ([]*armresources.ResourceGroup, error) { + cacheKey := "resourcesListAllResourceGroups" + if v := s.cache.Get(cacheKey); v != nil { + return v.([]*armresources.ResourceGroup), nil + } + + pager := s.client.List(nil) + results := make([]*armresources.ResourceGroup, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.ResourceGroupsListResult.Value...) + } + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} diff --git a/enumeration/remote/azurerm/repository/resources_test.go b/enumeration/remote/azurerm/repository/resources_test.go new file mode 100644 index 00000000..dcc97d94 --- /dev/null +++ b/enumeration/remote/azurerm/repository/resources_test.go @@ -0,0 +1,152 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_Resources_ListAllResourceGroups(t *testing.T) { + expectedResults := []*armresources.ResourceGroup{ + { + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/elie-dev"), + Name: to.StringPtr("elie-dev"), + }, + { + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/william-dev"), + Name: to.StringPtr("william-dev"), + }, + { + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/driftctl-sj-tests"), + Name: to.StringPtr("driftctl-sj-tests"), + }, + } + + testcases := []struct { + name string + mocks func(*mockResourcesListPager, *cache.MockCache) + expected []*armresources.ResourceGroup + wantErr string + }{ + { + name: "should return resource groups", + mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ + ResourceGroupsListResult: armresources.ResourceGroupsListResult{ + ResourceGroupListResult: armresources.ResourceGroupListResult{ + Value: []*armresources.ResourceGroup{ + { + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/elie-dev"), + Name: to.StringPtr("elie-dev"), + }, + { + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/william-dev"), + Name: to.StringPtr("william-dev"), + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ + ResourceGroupsListResult: armresources.ResourceGroupsListResult{ + ResourceGroupListResult: armresources.ResourceGroupListResult{ + Value: []*armresources.ResourceGroup{ + { + ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/driftctl-sj-tests"), + Name: to.StringPtr("driftctl-sj-tests"), + }, + }, + }, + }, + }).Times(1) + + mockCache.On("Get", "resourcesListAllResourceGroups").Return(nil).Times(1) + mockCache.On("Put", "resourcesListAllResourceGroups", expectedResults).Return(true).Times(1) + }, + expected: expectedResults, + }, + { + name: "should hit cache and return resource groups", + mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { + mockCache.On("Get", "resourcesListAllResourceGroups").Return(expectedResults).Times(1) + }, + expected: expectedResults, + }, + { + name: "should return remote error", + mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ + ResourceGroupsListResult: armresources.ResourceGroupsListResult{ + ResourceGroupListResult: armresources.ResourceGroupListResult{ + Value: []*armresources.ResourceGroup{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "resourcesListAllResourceGroups").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + { + name: "should return remote error after fetching all pages", + mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ + ResourceGroupsListResult: armresources.ResourceGroupsListResult{ + ResourceGroupListResult: armresources.ResourceGroupListResult{ + Value: []*armresources.ResourceGroup{}, + }, + }, + }).Times(1) + mockPager.On("Err").Return(nil).Times(1) + mockPager.On("Err").Return(errors.New("remote error")).Times(1) + + mockCache.On("Get", "resourcesListAllResourceGroups").Return(nil).Times(1) + }, + wantErr: "remote error", + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + fakeClient := &mockResourcesClient{} + mockPager := &mockResourcesListPager{} + mockCache := &cache.MockCache{} + + fakeClient.On("List", mock.Anything).Maybe().Return(mockPager) + + tt.mocks(mockPager, mockCache) + + s := &resourcesRepository{ + client: fakeClient, + cache: mockCache, + } + got, err := s.ListAllResourceGroups() + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + } else { + assert.Nil(t, err) + } + + fakeClient.AssertExpectations(t) + mockPager.AssertExpectations(t) + mockCache.AssertExpectations(t) + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/enumeration/remote/azurerm/repository/storage.go b/enumeration/remote/azurerm/repository/storage.go new file mode 100644 index 00000000..94cc3e87 --- /dev/null +++ b/enumeration/remote/azurerm/repository/storage.go @@ -0,0 +1,128 @@ +package repository + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/azurerm/common" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" + "github.com/Azure/go-autorest/autorest/azure" +) + +type StorageRespository interface { + ListAllStorageAccount() ([]*armstorage.StorageAccount, error) + ListAllStorageContainer(account *armstorage.StorageAccount) ([]string, error) +} + +type blobContainerListPager interface { + pager + PageResponse() armstorage.BlobContainersListResponse +} + +// Interfaces are only used to create mock on Azure SDK +type blobContainerClient interface { + List(resourceGroupName string, accountName string, options *armstorage.BlobContainersListOptions) blobContainerListPager +} + +type blobContainerClientImpl struct { + client *armstorage.BlobContainersClient +} + +func (c blobContainerClientImpl) List(resourceGroupName string, accountName string, options *armstorage.BlobContainersListOptions) blobContainerListPager { + return c.client.List(resourceGroupName, accountName, options) +} + +type storageAccountListPager interface { + pager + PageResponse() armstorage.StorageAccountsListResponse +} + +type storageAccountClient interface { + List(options *armstorage.StorageAccountsListOptions) storageAccountListPager +} + +type storageAccountClientImpl struct { + client *armstorage.StorageAccountsClient +} + +func (c storageAccountClientImpl) List(options *armstorage.StorageAccountsListOptions) storageAccountListPager { + return c.client.List(options) +} + +type storageRepository struct { + storageAccountsClient storageAccountClient + blobContainerClient blobContainerClient + cache cache.Cache +} + +func NewStorageRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *storageRepository { + return &storageRepository{ + storageAccountClientImpl{client: armstorage.NewStorageAccountsClient(config.SubscriptionID, cred, options)}, + blobContainerClientImpl{client: armstorage.NewBlobContainersClient(config.SubscriptionID, cred, options)}, + cache, + } +} + +func (s *storageRepository) ListAllStorageAccount() ([]*armstorage.StorageAccount, error) { + + cacheKey := "ListAllStorageAccount" + v := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if v != nil { + return v.([]*armstorage.StorageAccount), nil + } + + pager := s.storageAccountsClient.List(nil) + results := make([]*armstorage.StorageAccount, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + results = append(results, resp.StorageAccountsListResult.StorageAccountListResult.Value...) + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} + +func (s *storageRepository) ListAllStorageContainer(account *armstorage.StorageAccount) ([]string, error) { + + cacheKey := fmt.Sprintf("ListAllStorageContainer_%s", *account.Name) + if v := s.cache.Get(cacheKey); v != nil { + return v.([]string), nil + } + + res, err := azure.ParseResourceID(*account.ID) + if err != nil { + return nil, err + } + + pager := s.blobContainerClient.List(res.ResourceGroup, *account.Name, nil) + results := make([]string, 0) + for pager.NextPage(context.Background()) { + resp := pager.PageResponse() + if err := pager.Err(); err != nil { + return nil, err + } + for _, item := range resp.BlobContainersListResult.ListContainerItems.Value { + results = append(results, fmt.Sprintf("%s%s", *account.Properties.PrimaryEndpoints.Blob, *item.Name)) + } + } + + if err := pager.Err(); err != nil { + return nil, err + } + + s.cache.Put(cacheKey, results) + + return results, nil +} diff --git a/enumeration/remote/azurerm/repository/storage_test.go b/enumeration/remote/azurerm/repository/storage_test.go new file mode 100644 index 00000000..9ec4bb48 --- /dev/null +++ b/enumeration/remote/azurerm/repository/storage_test.go @@ -0,0 +1,373 @@ +package repository + +import ( + cache2 "github.com/snyk/driftctl/enumeration/remote/cache" + "reflect" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func Test_ListAllStorageAccount_MultiplesResults(t *testing.T) { + + expected := []*armstorage.StorageAccount{ + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account1"), + }, + }, + }, + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account2"), + }, + }, + }, + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account3"), + }, + }, + }, + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account4"), + }, + }, + }, + } + + fakeClient := &mockStorageAccountClient{} + + mockPager := &mockStorageAccountListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armstorage.StorageAccountsListResponse{ + StorageAccountsListResult: armstorage.StorageAccountsListResult{ + StorageAccountListResult: armstorage.StorageAccountListResult{ + Value: []*armstorage.StorageAccount{ + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account1"), + }, + }, + }, + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account2"), + }, + }, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armstorage.StorageAccountsListResponse{ + StorageAccountsListResult: armstorage.StorageAccountsListResult{ + StorageAccountListResult: armstorage.StorageAccountListResult{ + Value: []*armstorage.StorageAccount{ + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account3"), + }, + }, + }, + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account4"), + }, + }, + }, + }, + }, + }, + }).Times(1) + + fakeClient.On("List", mock.Anything).Return(mockPager) + + c := &cache2.MockCache{} + c.On("GetAndLock", "ListAllStorageAccount").Return(nil).Times(1) + c.On("Unlock", "ListAllStorageAccount").Times(1) + c.On("Put", "ListAllStorageAccount", expected).Return(true).Times(1) + s := &storageRepository{ + storageAccountsClient: fakeClient, + cache: c, + } + got, err := s.ListAllStorageAccount() + if err != nil { + t.Errorf("ListAllStorageAccount() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllStorageAccount() got = %v, want %v", got, expected) + } +} + +func Test_ListAllStorageAccount_MultiplesResults_WithCache(t *testing.T) { + + expected := []*armstorage.StorageAccount{ + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("account1"), + }, + }, + }, + } + + fakeClient := &mockStorageAccountClient{} + + c := &cache2.MockCache{} + c.On("GetAndLock", "ListAllStorageAccount").Return(expected).Times(1) + c.On("Unlock", "ListAllStorageAccount").Times(1) + s := &storageRepository{ + storageAccountsClient: fakeClient, + cache: c, + } + got, err := s.ListAllStorageAccount() + if err != nil { + t.Errorf("ListAllStorageAccount() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + if !reflect.DeepEqual(got, expected) { + t.Errorf("ListAllStorageAccount() got = %v, want %v", got, expected) + } +} + +func Test_ListAllStorageAccount_Error(t *testing.T) { + + fakeClient := &mockStorageAccountClient{} + + expectedErr := errors.New("unexpected error") + + mockPager := &mockStorageAccountListPager{} + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("PageResponse").Return(armstorage.StorageAccountsListResponse{}).Times(1) + + fakeClient.On("List", mock.Anything).Return(mockPager) + + s := &storageRepository{ + storageAccountsClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllStorageAccount() + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + + assert.Equal(t, expectedErr, err) + assert.Nil(t, got) +} + +func Test_ListAllStorageContainer_MultiplesResults(t *testing.T) { + + account := armstorage.StorageAccount{ + Properties: &armstorage.StorageAccountProperties{ + PrimaryEndpoints: &armstorage.Endpoints{ + Blob: to.StringPtr("https://testeliedriftctl.blob.core.windows.net/"), + }, + }, + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/foobar/providers/Microsoft.Storage/storageAccounts/testeliedriftctl"), + Name: to.StringPtr("testeliedriftctl"), + }, + }, + } + + expected := []string{ + "https://testeliedriftctl.blob.core.windows.net/container1", + "https://testeliedriftctl.blob.core.windows.net/container2", + "https://testeliedriftctl.blob.core.windows.net/container3", + "https://testeliedriftctl.blob.core.windows.net/container4", + } + + fakeClient := &mockBlobContainerClient{} + + mockPager := &mockBlobContainerListPager{} + mockPager.On("Err").Return(nil).Times(3) + mockPager.On("NextPage", mock.Anything).Return(true).Times(2) + mockPager.On("NextPage", mock.Anything).Return(false).Times(1) + mockPager.On("PageResponse").Return(armstorage.BlobContainersListResponse{ + BlobContainersListResult: armstorage.BlobContainersListResult{ + ListContainerItems: armstorage.ListContainerItems{ + Value: []*armstorage.ListContainerItem{ + { + AzureEntityResource: armstorage.AzureEntityResource{ + Resource: armstorage.Resource{ + Name: to.StringPtr("container1"), + }, + }, + }, + { + AzureEntityResource: armstorage.AzureEntityResource{ + Resource: armstorage.Resource{ + Name: to.StringPtr("container2"), + }, + }, + }, + }, + }, + }, + }).Times(1) + mockPager.On("PageResponse").Return(armstorage.BlobContainersListResponse{ + BlobContainersListResult: armstorage.BlobContainersListResult{ + ListContainerItems: armstorage.ListContainerItems{ + Value: []*armstorage.ListContainerItem{ + { + AzureEntityResource: armstorage.AzureEntityResource{ + Resource: armstorage.Resource{ + Name: to.StringPtr("container3"), + }, + }, + }, + { + AzureEntityResource: armstorage.AzureEntityResource{ + Resource: armstorage.Resource{ + Name: to.StringPtr("container4"), + }, + }, + }, + }, + }, + }, + }).Times(1) + + fakeClient.On("List", "foobar", "testeliedriftctl", (*armstorage.BlobContainersListOptions)(nil)).Return(mockPager) + + c := &cache2.MockCache{} + c.On("Get", "ListAllStorageContainer_testeliedriftctl").Return(nil).Times(1) + c.On("Put", "ListAllStorageContainer_testeliedriftctl", expected).Return(true).Times(1) + s := &storageRepository{ + blobContainerClient: fakeClient, + cache: c, + } + got, err := s.ListAllStorageContainer(&account) + if err != nil { + t.Errorf("ListAllStorageAccount() error = %v", err) + return + } + + mockPager.AssertExpectations(t) + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + assert.Equal(t, expected, got) +} + +func Test_ListAllStorageContainer_MultiplesResults_WithCache(t *testing.T) { + + account := armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/foobar/providers/Microsoft.Storage/storageAccounts/testeliedriftctl"), + Name: to.StringPtr("testeliedriftctl"), + }, + }, + } + + expected := []string{ + "https://testeliedriftctl.blob.core.windows.net/container1", + } + + fakeClient := &mockBlobContainerClient{} + + c := &cache2.MockCache{} + c.On("Get", "ListAllStorageContainer_testeliedriftctl").Return(expected).Times(1) + s := &storageRepository{ + blobContainerClient: fakeClient, + cache: c, + } + got, err := s.ListAllStorageContainer(&account) + if err != nil { + t.Errorf("ListAllStorageAccount() error = %v", err) + return + } + + fakeClient.AssertExpectations(t) + c.AssertExpectations(t) + + assert.Equal(t, expected, got) +} + +func Test_ListAllStorageContainer_InvalidStorageAccountResourceID(t *testing.T) { + + account := armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: to.StringPtr("foobar"), + Name: to.StringPtr(""), + }, + }, + } + + fakeClient := &mockBlobContainerClient{} + + s := &storageRepository{ + blobContainerClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllStorageContainer(&account) + + fakeClient.AssertExpectations(t) + + assert.Nil(t, got) + assert.Equal(t, "parsing failed for foobar. Invalid resource Id format", err.Error()) +} + +func Test_ListAllStorageContainer_Error(t *testing.T) { + + account := armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/foobar/providers/Microsoft.Storage/storageAccounts/testeliedriftctl"), + Name: to.StringPtr("testeliedriftctl"), + }, + }, + } + + expectedErr := errors.New("sample error") + + fakeClient := &mockBlobContainerClient{} + mockPager := &mockBlobContainerListPager{} + mockPager.On("NextPage", mock.Anything).Return(true).Times(1) + mockPager.On("Err").Return(expectedErr).Times(1) + mockPager.On("PageResponse").Return(armstorage.BlobContainersListResponse{}).Times(1) + + fakeClient.On("List", "foobar", "testeliedriftctl", (*armstorage.BlobContainersListOptions)(nil)).Return(mockPager) + + s := &storageRepository{ + blobContainerClient: fakeClient, + cache: cache2.New(0), + } + got, err := s.ListAllStorageContainer(&account) + + fakeClient.AssertExpectations(t) + mockPager.AssertExpectations(t) + + assert.Nil(t, got) + assert.Equal(t, expectedErr, err) +} diff --git a/enumeration/remote/azurerm_compute_scanner_test.go b/enumeration/remote/azurerm_compute_scanner_test.go new file mode 100644 index 00000000..cb12ac8f --- /dev/null +++ b/enumeration/remote/azurerm_compute_scanner_test.go @@ -0,0 +1,235 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + azurerm2 "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermCompute_Image(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockComputeRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no images", + mocks: func(repository *repository2.MockComputeRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllImages").Return([]*armcompute.Image{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing images", + mocks: func(repository *repository2.MockComputeRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllImages").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzureImageResourceType), + }, + { + test: "multiple images including an invalid ID", + mocks: func(repository *repository2.MockComputeRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllImages").Return([]*armcompute.Image{ + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image1"), + Name: to.StringPtr("image1"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image2"), + Name: to.StringPtr("image2"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/invalid-id/image3"), + Name: to.StringPtr("image3"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureImageResourceType) + + assert.Equal(t, got[1].ResourceId(), "/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureImageResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockComputeRepository{} + c.mocks(fakeRepo, alerter) + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermImageEnumerator(fakeRepo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermCompute_SSHPublicKey(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockComputeRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no public key", + dirName: "azurerm_ssh_public_key_empty", + mocks: func(repository *repository2.MockComputeRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSSHPublicKeys").Return([]*armcompute.SSHPublicKeyResource{}, nil) + }, + }, + { + test: "error listing public keys", + dirName: "azurerm_ssh_public_key_empty", + mocks: func(repository *repository2.MockComputeRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSSHPublicKeys").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzureSSHPublicKeyResourceType), + }, + { + test: "multiple public keys", + dirName: "azurerm_ssh_public_key_multiple", + mocks: func(repository *repository2.MockComputeRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSSHPublicKeys").Return([]*armcompute.SSHPublicKeyResource{ + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key"), + Name: to.StringPtr("example-key"), + }, + }, + { + Resource: armcompute.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key2"), + Name: to.StringPtr("example-key2"), + }, + }, + }, nil) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockComputeRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ComputeRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraform2.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewComputeRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermSSHPublicKeyEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzureSSHPublicKeyResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzureSSHPublicKeyResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzureSSHPublicKeyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm_containerregistry_scanner_test.go b/enumeration/remote/azurerm_containerregistry_scanner_test.go new file mode 100644 index 00000000..070e77d6 --- /dev/null +++ b/enumeration/remote/azurerm_containerregistry_scanner_test.go @@ -0,0 +1,114 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + error2 "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermContainerRegistry(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockContainerRegistryRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no container registry", + mocks: func(repository *repository2.MockContainerRegistryRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllContainerRegistries").Return([]*armcontainerregistry.Registry{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing container registry", + mocks: func(repository *repository2.MockContainerRegistryRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllContainerRegistries").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureContainerRegistryResourceType), + }, + { + test: "multiple container registries", + mocks: func(repository *repository2.MockContainerRegistryRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllContainerRegistries").Return([]*armcontainerregistry.Registry{ + { + Resource: armcontainerregistry.Resource{ + ID: to.StringPtr("registry1"), + }, + }, + { + Resource: armcontainerregistry.Resource{ + ID: to.StringPtr("registry2"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "registry1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureContainerRegistryResourceType) + + assert.Equal(t, got[1].ResourceId(), "registry2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureContainerRegistryResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockContainerRegistryRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ContainerRegistryRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm.NewAzurermContainerRegistryEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm_network_scanner_test.go b/enumeration/remote/azurerm_network_scanner_test.go new file mode 100644 index 00000000..629ce5a7 --- /dev/null +++ b/enumeration/remote/azurerm_network_scanner_test.go @@ -0,0 +1,1014 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + azurerm2 "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + error2 "github.com/snyk/driftctl/enumeration/remote/error" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermVirtualNetwork(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no virtual network", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return([]*armnetwork.VirtualNetwork{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing virtual network", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureVirtualNetworkResourceType), + }, + { + test: "multiple virtual network", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return([]*armnetwork.VirtualNetwork{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network1"), + Name: to.StringPtr("network1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network2"), + Name: to.StringPtr("network2"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "network1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureVirtualNetworkResourceType) + + assert.Equal(t, got[1].ResourceId(), "network2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureVirtualNetworkResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermVirtualNetworkEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermRouteTables(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no route tables", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing route tables", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureRouteTableResourceType), + }, + { + test: "multiple route tables", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("route1"), + Name: to.StringPtr("route1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("route2"), + Name: to.StringPtr("route2"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "route1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureRouteTableResourceType) + + assert.Equal(t, got[1].ResourceId(), "route2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureRouteTableResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermRouteTableEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermRoutes(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no route tables", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "no routes", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{ + { + Properties: &armnetwork.RouteTablePropertiesFormat{ + Routes: []*armnetwork.Route{}, + }, + }, + { + Properties: &armnetwork.RouteTablePropertiesFormat{ + Routes: []*armnetwork.Route{}, + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing route tables", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureRouteResourceType, resourceazure.AzureRouteTableResourceType), + }, + { + test: "multiple routes", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{ + { + Resource: armnetwork.Resource{ + Name: to.StringPtr("table1"), + }, + Properties: &armnetwork.RouteTablePropertiesFormat{ + Routes: []*armnetwork.Route{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("route1"), + }, + Name: to.StringPtr("route1"), + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("route2"), + }, + Name: to.StringPtr("route2"), + }, + }, + }, + }, + { + Resource: armnetwork.Resource{ + Name: to.StringPtr("table2"), + }, + Properties: &armnetwork.RouteTablePropertiesFormat{ + Routes: []*armnetwork.Route{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("route3"), + }, + Name: to.StringPtr("route3"), + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("route4"), + }, + Name: to.StringPtr("route4"), + }, + }, + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 4) + + assert.Equal(t, "route1", got[0].ResourceId()) + assert.Equal(t, resourceazure.AzureRouteResourceType, got[0].ResourceType()) + + assert.Equal(t, "route2", got[1].ResourceId()) + assert.Equal(t, resourceazure.AzureRouteResourceType, got[1].ResourceType()) + + assert.Equal(t, "route3", got[2].ResourceId()) + assert.Equal(t, resourceazure.AzureRouteResourceType, got[2].ResourceType()) + + assert.Equal(t, "route4", got[3].ResourceId()) + assert.Equal(t, resourceazure.AzureRouteResourceType, got[3].ResourceType()) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermRouteEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermSubnets(t *testing.T) { + + dummyError := errors.New("this is an error") + + networks := []*armnetwork.VirtualNetwork{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("network2"), + }, + }, + } + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no subnets", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return(networks, nil) + repository.On("ListAllSubnets", networks[0]).Return([]*armnetwork.Subnet{}, nil).Times(1) + repository.On("ListAllSubnets", networks[1]).Return([]*armnetwork.Subnet{}, nil).Times(1) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing virtual network", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureSubnetResourceType, resourceazure.AzureVirtualNetworkResourceType), + }, + { + test: "error listing subnets", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return(networks, nil) + repository.On("ListAllSubnets", networks[0]).Return(nil, dummyError).Times(1) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureSubnetResourceType), + }, + { + test: "multiple subnets", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllVirtualNetworks").Return(networks, nil) + repository.On("ListAllSubnets", networks[0]).Return([]*armnetwork.Subnet{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet1"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet2"), + }, + }, + }, nil).Times(1) + repository.On("ListAllSubnets", networks[1]).Return([]*armnetwork.Subnet{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet3"), + }, + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("subnet4"), + }, + }, + }, nil).Times(1) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 4) + + assert.Equal(t, got[0].ResourceId(), "subnet1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureSubnetResourceType) + + assert.Equal(t, got[1].ResourceId(), "subnet2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureSubnetResourceType) + + assert.Equal(t, got[2].ResourceId(), "subnet3") + assert.Equal(t, got[2].ResourceType(), resourceazure.AzureSubnetResourceType) + + assert.Equal(t, got[3].ResourceId(), "subnet4") + assert.Equal(t, got[3].ResourceType(), resourceazure.AzureSubnetResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermSubnetEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermFirewalls(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no firewall", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllFirewalls").Return([]*armnetwork.AzureFirewall{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing firewalls", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllFirewalls").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureFirewallResourceType), + }, + { + test: "multiple firewalls", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllFirewalls").Return([]*armnetwork.AzureFirewall{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall1"), // Here we don't care to have a valid ID, it is for testing purpose only + Name: to.StringPtr("firewall1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("firewall2"), + Name: to.StringPtr("firewall2"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "firewall1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureFirewallResourceType) + + assert.Equal(t, got[1].ResourceId(), "firewall2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureFirewallResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermFirewallsEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPublicIP(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no public IP", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPublicIPAddresses").Return([]*armnetwork.PublicIPAddress{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing public IPs", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPublicIPAddresses").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzurePublicIPResourceType), + }, + { + test: "multiple public IP", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPublicIPAddresses").Return([]*armnetwork.PublicIPAddress{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip1"), // Here we don't care to have a valid ID, it is for testing purpose only + Name: to.StringPtr("ip1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("ip2"), + Name: to.StringPtr("ip2"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "ip1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzurePublicIPResourceType) + + assert.Equal(t, got[1].ResourceId(), "ip2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzurePublicIPResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPublicIPEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermSecurityGroups(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no security group", + dirName: "azurerm_network_security_group_empty", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSecurityGroups").Return([]*armnetwork.NetworkSecurityGroup{}, nil) + }, + }, + { + test: "error listing security groups", + dirName: "azurerm_network_security_group_empty", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSecurityGroups").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureNetworkSecurityGroupResourceType), + }, + { + test: "multiple security groups", + dirName: "azurerm_network_security_group_multiple", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllSecurityGroups").Return([]*armnetwork.NetworkSecurityGroup{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup1"), + Name: to.StringPtr("acceptanceTestSecurityGroup1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup2"), + Name: to.StringPtr("acceptanceTestSecurityGroup2"), + }, + }, + }, nil) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraform2.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewNetworkRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermNetworkSecurityGroupEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzureNetworkSecurityGroupResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzureNetworkSecurityGroupResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzureNetworkSecurityGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermLoadBalancers(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no load balancer", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing load balancers", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureLoadBalancerResourceType), + }, + { + test: "multiple load balancers", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{ + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("lb-1"), // Here we don't care to have a valid ID, it is for testing purpose only + Name: to.StringPtr("lb-1"), + }, + }, + { + Resource: armnetwork.Resource{ + ID: to.StringPtr("lb-2"), + Name: to.StringPtr("lb-2"), + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "lb-1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureLoadBalancerResourceType) + + assert.Equal(t, got[1].ResourceId(), "lb-2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureLoadBalancerResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermLoadBalancerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermLoadBalancerRules(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockNetworkRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no load balancer rule", + dirName: "azurerm_lb_rule_empty", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + loadbalancer := &armnetwork.LoadBalancer{ + Resource: armnetwork.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress"), + Name: to.StringPtr("testlb"), + }, + } + + repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{loadbalancer}, nil) + + repository.On("ListLoadBalancerRules", loadbalancer).Return([]*armnetwork.LoadBalancingRule{}, nil) + }, + }, + { + test: "error listing load balancer rules", + dirName: "azurerm_lb_rule_empty", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllLoadBalancers").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureLoadBalancerRuleResourceType, resourceazure.AzureLoadBalancerResourceType), + }, + { + test: "multiple load balancer rules", + dirName: "azurerm_lb_rule_multiple", + mocks: func(repository *repository2.MockNetworkRepository, alerter *mocks.AlerterInterface) { + loadbalancer := &armnetwork.LoadBalancer{ + Resource: armnetwork.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress"), + Name: to.StringPtr("TestLoadBalancer"), + }, + } + + repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{loadbalancer}, nil) + + repository.On("ListLoadBalancerRules", loadbalancer).Return([]*armnetwork.LoadBalancingRule{ + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule"), + }, + Name: to.StringPtr("LBRule"), + }, + { + SubResource: armnetwork.SubResource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule2"), + }, + Name: to.StringPtr("LBRule2"), + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockNetworkRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.NetworkRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraform2.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewNetworkRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermLoadBalancerRuleEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzureLoadBalancerRuleResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzureLoadBalancerRuleResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzureLoadBalancerRuleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm_postgresql_scanner_test.go b/enumeration/remote/azurerm_postgresql_scanner_test.go new file mode 100644 index 00000000..9de49e5d --- /dev/null +++ b/enumeration/remote/azurerm_postgresql_scanner_test.go @@ -0,0 +1,246 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + azurerm2 "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermPostgresqlServer(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockPostgresqlRespository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no postgres server", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return([]*armpostgresql.Server{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing postgres servers", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePostgresqlServerResourceType), + }, + { + test: "multiple postgres servers", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return([]*armpostgresql.Server{ + { + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("server1"), + Name: to.StringPtr("server1"), + }, + }, + }, + { + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("server2"), + Name: to.StringPtr("server2"), + }, + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "server1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzurePostgresqlServerResourceType) + + assert.Equal(t, got[1].ResourceId(), "server2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzurePostgresqlServerResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPostgresqlRespository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PostgresqlRespository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPostgresqlServerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPostgresqlDatabase(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockPostgresqlRespository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no postgres database", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return([]*armpostgresql.Server{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing postgres servers", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePostgresqlDatabaseResourceType, resourceazure.AzurePostgresqlServerResourceType), + }, + { + test: "error listing postgres databases", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return([]*armpostgresql.Server{ + { + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/api-rg-pro/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-8791542"), + Name: to.StringPtr("postgresql-server-8791542"), + }, + }, + }, + }, nil).Once() + + repository.On("ListAllDatabasesByServer", mock.IsType(&armpostgresql.Server{})).Return(nil, dummyError).Once() + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePostgresqlDatabaseResourceType), + }, + { + test: "multiple postgres databases", + mocks: func(repository *repository2.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllServers").Return([]*armpostgresql.Server{ + { + TrackedResource: armpostgresql.TrackedResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/api-rg-pro/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-8791542"), + Name: to.StringPtr("postgresql-server-8791542"), + }, + }, + }, + }, nil).Once() + + repository.On("ListAllDatabasesByServer", mock.IsType(&armpostgresql.Server{})).Return([]*armpostgresql.Database{ + { + ProxyResource: armpostgresql.ProxyResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("db1"), + Name: to.StringPtr("db1"), + }, + }, + }, + { + ProxyResource: armpostgresql.ProxyResource{ + Resource: armpostgresql.Resource{ + ID: to.StringPtr("db2"), + Name: to.StringPtr("db2"), + }, + }, + }, + }, nil).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "db1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzurePostgresqlDatabaseResourceType) + + assert.Equal(t, got[1].ResourceId(), "db2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzurePostgresqlDatabaseResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPostgresqlRespository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PostgresqlRespository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPostgresqlDatabaseEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm_privatedns_scanner_test.go b/enumeration/remote/azurerm_privatedns_scanner_test.go new file mode 100644 index 00000000..e546d41e --- /dev/null +++ b/enumeration/remote/azurerm_privatedns_scanner_test.go @@ -0,0 +1,1220 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + azurerm2 "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraformtest "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermPrivateDNSZone(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private zone", + dirName: "azurerm_private_dns_private_zone_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zones", + dirName: "azurerm_private_dns_private_zone_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "multiple private zones", + dirName: "azurerm_private_dns_private_zone_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf2.com"), + Name: to.StringPtr("thisisatestusingtf2.com"), + }, + }, + }, + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/testmartin.com"), + Name: to.StringPtr("testmartin.com"), + }, + }, + }, + }, nil) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSZoneEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSZoneResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSZoneResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSZoneResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSARecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private a record", + dirName: "azurerm_private_dns_a_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_a_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSARecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private a records", + dirName: "azurerm_private_dns_a_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllARecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSARecordResourceType), + }, + { + test: "multiple private a records", + dirName: "azurerm_private_dns_a_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllARecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/test"), + Name: to.StringPtr("test"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("10.0.180.17")}, + {IPv4Address: to.StringPtr("10.0.180.20")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/othertest"), + Name: to.StringPtr("othertest"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + ARecords: []*armprivatedns.ARecord{ + {IPv4Address: to.StringPtr("10.0.180.20")}, + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSARecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSARecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSARecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSARecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSAAAARecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private aaaa record", + dirName: "azurerm_private_dns_aaaa_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_aaaa_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSAAAARecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private aaaa records", + dirName: "azurerm_private_dns_aaaa_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllAAAARecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSAAAARecordResourceType), + }, + { + test: "multiple private aaaaa records", + dirName: "azurerm_private_dns_aaaaa_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllAAAARecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/test"), + Name: to.StringPtr("test"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008:0000:0000:0000:7334")}, + {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008::7335")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/othertest"), + Name: to.StringPtr("othertest"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + AaaaRecords: []*armprivatedns.AaaaRecord{ + {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008:0000:0000:0000:7334")}, + {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008::7335")}, + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSAAAARecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSAAAARecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSAAAARecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSAAAARecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSCNAMERecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private cname record", + dirName: "azurerm_private_dns_cname_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_cname_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSCNameRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private cname records", + dirName: "azurerm_private_dns_cname_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllCNAMERecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSCNameRecordResourceType), + }, + { + test: "multiple private cname records", + dirName: "azurerm_private_dns_cname_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllCNAMERecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/test"), + Name: to.StringPtr("test"), + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/othertest"), + Name: to.StringPtr("othertest"), + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSCNameRecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSCNameRecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSCNameRecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSCNameRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSPTRRecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private ptr record", + dirName: "azurerm_private_dns_ptr_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_ptr_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSPTRRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private ptr records", + dirName: "azurerm_private_dns_ptr_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllPTRRecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSPTRRecordResourceType), + }, + { + test: "multiple private ptra records", + dirName: "azurerm_private_dns_ptr_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllPTRRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/othertestptr"), + Name: to.StringPtr("othertestptr"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptr1.thisisatestusingtf.com")}, + {Ptrdname: to.StringPtr("ptr2.thisisatestusingtf.com")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/testptr"), + Name: to.StringPtr("testptr"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("ptr3.thisisatestusingtf.com")}, + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSPTRRecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSPTRRecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSPTRRecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSPTRRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSMXRecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private mx record", + dirName: "azurerm_private_dns_mx_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_mx_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSMXRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private mx records", + dirName: "azurerm_private_dns_mx_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllMXRecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSMXRecordResourceType), + }, + { + test: "multiple private mx records", + dirName: "azurerm_private_dns_mx_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllMXRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/othertestmx"), + Name: to.StringPtr("othertestmx"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex1")}, + {Exchange: to.StringPtr("ex2")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/testmx"), + Name: to.StringPtr("testmx"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + MxRecords: []*armprivatedns.MxRecord{ + {Exchange: to.StringPtr("ex1")}, + {Exchange: to.StringPtr("ex2")}, + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSMXRecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSMXRecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSMXRecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSMXRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSSRVRecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private srv record", + dirName: "azurerm_private_dns_srv_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_srv_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSSRVRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private srv records", + dirName: "azurerm_private_dns_srv_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllSRVRecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSSRVRecordResourceType), + }, + { + test: "multiple private srv records", + dirName: "azurerm_private_dns_srv_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllSRVRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/othertestptr"), + Name: to.StringPtr("othertestptr"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + SrvRecords: []*armprivatedns.SrvRecord{ + {Target: to.StringPtr("srv1.thisisatestusingtf.com")}, + {Target: to.StringPtr("srv2.thisisatestusingtf.com")}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/testptr"), + Name: to.StringPtr("testptr"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("srv3.thisisatestusingtf.com")}, + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSSRVRecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSSRVRecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSSRVRecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSSRVRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermPrivateDNSTXTRecord(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + dirName string + mocks func(*repository2.MockPrivateDNSRepository, *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no private txt record", + dirName: "azurerm_private_dns_txt_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) + }, + }, + { + test: "error listing private zone", + dirName: "azurerm_private_dns_txt_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSTXTRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), + }, + { + test: "error listing private txt records", + dirName: "azurerm_private_dns_txt_record_empty", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + repository.On("ListAllTXTRecords", mock.Anything).Return(nil, dummyError) + }, + wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSTXTRecordResourceType), + }, + { + test: "multiple private txt records", + dirName: "azurerm_private_dns_txt_record_multiple", + mocks: func(repository *repository2.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ + { + TrackedResource: armprivatedns.TrackedResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), + Name: to.StringPtr("thisisatestusingtf.com"), + }, + }, + }, + }, nil) + + repository.On("ListAllTXTRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/othertesttxt"), + Name: to.StringPtr("othertesttxt"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + TxtRecords: []*armprivatedns.TxtRecord{ + {Value: []*string{to.StringPtr("this is value line 1")}}, + {Value: []*string{to.StringPtr("this is value line 2")}}, + }, + }, + }, + { + ProxyResource: armprivatedns.ProxyResource{ + Resource: armprivatedns.Resource{ + ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/testtxt"), + Name: to.StringPtr("testtxt"), + }, + }, + Properties: &armprivatedns.RecordSetProperties{ + PtrRecords: []*armprivatedns.PtrRecord{ + {Ptrdname: to.StringPtr("this is value line 3")}, + }, + }, + }, + }, nil).Once() + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockPrivateDNSRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.PrivateDNSRepository = fakeRepo + providerVersion := "2.71.0" + realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) + if err != nil { + t.Fatal(err) + } + provider := terraformtest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) + if err != nil { + t.Fatal(err) + } + clientOptions := &arm.ClientOptions{} + repo = repository2.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermPrivateDNSTXTRecordEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSTXTRecordResourceType, common2.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSTXTRecordResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + + if err != nil { + return + } + test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSTXTRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm_resources_scanner_test.go b/enumeration/remote/azurerm_resources_scanner_test.go new file mode 100644 index 00000000..40b427f9 --- /dev/null +++ b/enumeration/remote/azurerm_resources_scanner_test.go @@ -0,0 +1,112 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + error2 "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermResourceGroups(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockResourcesRepository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no resource group", + mocks: func(repository *repository2.MockResourcesRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllResourceGroups").Return([]*armresources.ResourceGroup{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing resource groups", + mocks: func(repository *repository2.MockResourcesRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllResourceGroups").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureResourceGroupResourceType), + }, + { + test: "multiple resource groups", + mocks: func(repository *repository2.MockResourcesRepository, alerter *mocks.AlerterInterface) { + repository.On("ListAllResourceGroups").Return([]*armresources.ResourceGroup{ + { + ID: to.StringPtr("group1"), + Name: to.StringPtr("group1"), + }, + { + ID: to.StringPtr("group2"), + Name: to.StringPtr("group2"), + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "group1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureResourceGroupResourceType) + + assert.Equal(t, got[1].ResourceId(), "group2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureResourceGroupResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockResourcesRepository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.ResourcesRepository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm.NewAzurermResourceGroupEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/azurerm_storage_scanner_test.go b/enumeration/remote/azurerm_storage_scanner_test.go new file mode 100644 index 00000000..581f1a54 --- /dev/null +++ b/enumeration/remote/azurerm_storage_scanner_test.go @@ -0,0 +1,262 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + azurerm2 "github.com/snyk/driftctl/enumeration/remote/azurerm" + repository2 "github.com/snyk/driftctl/enumeration/remote/azurerm/repository" + "github.com/snyk/driftctl/enumeration/remote/common" + error2 "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAzurermStorageAccount(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockStorageRespository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no storage accounts", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing storage accounts", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStorageAccount").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureStorageAccountResourceType), + }, + { + test: "multiple storage accounts", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{ + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl1"), + }, + }, + }, + { + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl2"), + }, + }, + }, + }, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "testeliedriftctl1") + assert.Equal(t, got[0].ResourceType(), resourceazure.AzureStorageAccountResourceType) + + assert.Equal(t, got[1].ResourceId(), "testeliedriftctl2") + assert.Equal(t, got[1].ResourceType(), resourceazure.AzureStorageAccountResourceType) + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockStorageRespository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.StorageRespository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermStorageAccountEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} + +func TestAzurermStorageContainer(t *testing.T) { + + dummyError := errors.New("this is an error") + + tests := []struct { + test string + mocks func(*repository2.MockStorageRespository, *mocks.AlerterInterface) + assertExpected func(t *testing.T, got []*resource.Resource) + wantErr error + }{ + { + test: "no storage accounts", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "no storage containers", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + account1 := &armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl1"), + }, + }, + } + account2 := &armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl1"), + }, + }, + } + repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{ + account1, + account2, + }, nil) + repository.On("ListAllStorageContainer", account1).Return([]string{}, nil) + repository.On("ListAllStorageContainer", account2).Return([]string{}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "error listing storage accounts", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + repository.On("ListAllStorageAccount").Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureStorageContainerResourceType, resourceazure.AzureStorageAccountResourceType), + }, + { + test: "error listing storage container", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + account := &armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl1"), + }, + }, + } + repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{account}, nil) + repository.On("ListAllStorageContainer", account).Return(nil, dummyError) + }, + wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureStorageContainerResourceType), + }, + { + test: "multiple storage containers", + mocks: func(repository *repository2.MockStorageRespository, alerter *mocks.AlerterInterface) { + account1 := &armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl1"), + }, + }, + } + account2 := &armstorage.StorageAccount{ + TrackedResource: armstorage.TrackedResource{ + Resource: armstorage.Resource{ + ID: func(s string) *string { return &s }("testeliedriftctl2"), + }, + }, + } + repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{ + account1, + account2, + }, nil) + repository.On("ListAllStorageContainer", account1).Return([]string{"https://testeliedriftctl1.blob.core.windows.net/container1", "https://testeliedriftctl1.blob.core.windows.net/container2"}, nil) + repository.On("ListAllStorageContainer", account2).Return([]string{"https://testeliedriftctl2.blob.core.windows.net/container3", "https://testeliedriftctl2.blob.core.windows.net/container4"}, nil) + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 4) + + for _, container := range got { + assert.Equal(t, container.ResourceType(), resourceazure.AzureStorageContainerResourceType) + } + + assert.Equal(t, got[0].ResourceId(), "https://testeliedriftctl1.blob.core.windows.net/container1") + assert.Equal(t, got[1].ResourceId(), "https://testeliedriftctl1.blob.core.windows.net/container2") + assert.Equal(t, got[2].ResourceId(), "https://testeliedriftctl2.blob.core.windows.net/container3") + assert.Equal(t, got[3].ResourceId(), "https://testeliedriftctl2.blob.core.windows.net/container4") + }, + }, + } + + providerVersion := "2.71.0" + schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) + resourceazure.InitResourcesMetadata(schemaRepository) + factory := terraform.NewTerraformResourceFactory(schemaRepository) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + + scanOptions := ScannerOptions{} + remoteLibrary := common.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + fakeRepo := &repository2.MockStorageRespository{} + c.mocks(fakeRepo, alerter) + + var repo repository2.StorageRespository = fakeRepo + + remoteLibrary.AddEnumerator(azurerm2.NewAzurermStorageContainerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + c.assertExpected(tt, got) + alerter.AssertExpectations(tt) + fakeRepo.AssertExpectations(tt) + }) + } +} diff --git a/pkg/remote/cache/cache.go b/enumeration/remote/cache/cache.go similarity index 100% rename from pkg/remote/cache/cache.go rename to enumeration/remote/cache/cache.go diff --git a/enumeration/remote/cache/cache_test.go b/enumeration/remote/cache/cache_test.go new file mode 100644 index 00000000..0598bd09 --- /dev/null +++ b/enumeration/remote/cache/cache_test.go @@ -0,0 +1,157 @@ +package cache + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/stretchr/testify/assert" +) + +func BenchmarkCache(b *testing.B) { + cache := New(500) + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("test-key-%d", i) + data := make([]*resource.Resource, 1024) + assert.Equal(b, false, cache.Put(key, data)) + assert.Equal(b, data, cache.Get(key)) + } +} + +func TestCache(t *testing.T) { + t.Run("should return nil on non-existing key", func(t *testing.T) { + cache := New(5) + assert.Equal(t, nil, cache.Get("test")) + assert.Equal(t, 0, cache.Len()) + }) + + t.Run("should retrieve newly added key", func(t *testing.T) { + cache := New(5) + assert.Equal(t, false, cache.Put("s3", []string{})) + assert.Equal(t, []string{}, cache.Get("s3")) + assert.Equal(t, 1, cache.Len()) + }) + + t.Run("should override existing key", func(t *testing.T) { + cache := New(5) + assert.Equal(t, false, cache.Put("s3", []string{})) + assert.Equal(t, []string{}, cache.Get("s3")) + + assert.Equal(t, true, cache.Put("s3", []string{"test"})) + assert.Equal(t, []string{"test"}, cache.Get("s3")) + assert.Equal(t, 1, cache.Len()) + }) + + t.Run("should delete the least used keys", func(t *testing.T) { + keys := []struct { + key string + value interface{} + }{ + {key: "test-0", value: nil}, + {key: "test-1", value: nil}, + {key: "test-2", value: nil}, + {key: "test-3", value: nil}, + {key: "test-4", value: nil}, + {key: "test-5", value: nil}, + {key: "test-6", value: "value"}, + {key: "test-7", value: "value"}, + {key: "test-8", value: "value"}, + {key: "test-9", value: "value"}, + {key: "test-10", value: "value"}, + } + + cache := New(5) + for i := 0; i <= 10; i++ { + cache.Put(fmt.Sprintf("test-%d", i), "value") + } + for _, k := range keys { + assert.Equal(t, k.value, cache.Get(k.key)) + } + assert.Equal(t, 5, cache.Len()) + }) + + t.Run("should ignore keys when capacity is 0", func(t *testing.T) { + keys := []struct { + key string + value interface{} + }{ + { + "test", + []string{"slice"}, + }, + { + "test", + []string{}, + }, + { + "test2", + []*resource.Resource{}, + }, + } + cache := New(0) + + for _, k := range keys { + assert.Equal(t, false, cache.Put(k.key, k.value)) + assert.Equal(t, nil, cache.Get(k.key)) + } + assert.Equal(t, 0, cache.Len()) + }) + + t.Run("cache will not panic for parallel calls", func(t *testing.T) { + key := "sameKeyForMultiplesRoutines" + + cache := New(1) + + wg := sync.WaitGroup{} + missCount := 0 + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + hit := cache.Get(key) + if hit != nil { + return + } + missCount++ + time.Sleep(10 * time.Millisecond) + cache.Put(key, "value") + }() + } + + wg.Wait() + assert.Equal(t, cache.Get(key), "value") + assert.Greater(t, missCount, 1) + }) + + t.Run("cache should be missed only once with parallel calls and GetAndLock usage", func(t *testing.T) { + key := "sameKeyForMultiplesRoutines" + + cache := New(1) + + nbRoutines := 100 + wg := sync.WaitGroup{} + wg.Add(nbRoutines) + + missCount := 0 + for i := 0; i < nbRoutines; i++ { + go func() { + defer wg.Done() + hit := cache.GetAndLock(key) + defer cache.Unlock(key) + if hit != nil { + return + } + missCount++ + time.Sleep(1 * time.Millisecond) + cache.Put(key, "value") + }() + } + + wg.Wait() + assert.Equal(t, cache.Get(key), "value") + assert.Equal(t, 1, missCount) + }) +} diff --git a/pkg/remote/cache/mock_Cache.go b/enumeration/remote/cache/mock_Cache.go similarity index 100% rename from pkg/remote/cache/mock_Cache.go rename to enumeration/remote/cache/mock_Cache.go diff --git a/enumeration/remote/common/details_fetcher.go b/enumeration/remote/common/details_fetcher.go new file mode 100644 index 00000000..f990a61b --- /dev/null +++ b/enumeration/remote/common/details_fetcher.go @@ -0,0 +1,54 @@ +package common + +import ( + "github.com/sirupsen/logrus" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/terraform" +) + +type DetailsFetcher interface { + ReadDetails(*resource.Resource) (*resource.Resource, error) +} + +type GenericDetailsFetcher struct { + resType resource.ResourceType + reader terraform.ResourceReader + deserializer *resource.Deserializer +} + +func NewGenericDetailsFetcher(resType resource.ResourceType, provider terraform.ResourceReader, deserializer *resource.Deserializer) *GenericDetailsFetcher { + return &GenericDetailsFetcher{ + resType: resType, + reader: provider, + deserializer: deserializer, + } +} + +func (f *GenericDetailsFetcher) ReadDetails(res *resource.Resource) (*resource.Resource, error) { + attributes := map[string]string{} + if res.Schema().ResolveReadAttributesFunc != nil { + attributes = res.Schema().ResolveReadAttributesFunc(res) + } + ctyVal, err := f.reader.ReadResource(terraform.ReadResourceArgs{ + Ty: f.resType, + ID: res.ResourceId(), + Attributes: attributes, + }) + if err != nil { + return nil, remoteerror.NewResourceScanningError(err, res.ResourceType(), res.ResourceId()) + } + if ctyVal.IsNull() { + logrus.WithFields(logrus.Fields{ + "type": f.resType, + "id": res.ResourceId(), + }).Debug("Got null while reading resource details") + return nil, nil + } + deserializedRes, err := f.deserializer.DeserializeOne(string(f.resType), *ctyVal) + if err != nil { + return nil, err + } + + return deserializedRes, nil +} diff --git a/enumeration/remote/common/library.go b/enumeration/remote/common/library.go new file mode 100644 index 00000000..4b0459ba --- /dev/null +++ b/enumeration/remote/common/library.go @@ -0,0 +1,38 @@ +package common + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +type Enumerator interface { + SupportedType() resource.ResourceType + Enumerate() ([]*resource.Resource, error) +} + +type RemoteLibrary struct { + enumerators []Enumerator + detailsFetchers map[resource.ResourceType]DetailsFetcher +} + +func NewRemoteLibrary() *RemoteLibrary { + return &RemoteLibrary{ + make([]Enumerator, 0), + make(map[resource.ResourceType]DetailsFetcher), + } +} + +func (r *RemoteLibrary) AddEnumerator(enumerator Enumerator) { + r.enumerators = append(r.enumerators, enumerator) +} + +func (r *RemoteLibrary) Enumerators() []Enumerator { + return r.enumerators +} + +func (r *RemoteLibrary) AddDetailsFetcher(ty resource.ResourceType, detailsFetcher DetailsFetcher) { + r.detailsFetchers[ty] = detailsFetcher +} + +func (r *RemoteLibrary) GetDetailsFetcher(ty resource.ResourceType) DetailsFetcher { + return r.detailsFetchers[ty] +} diff --git a/enumeration/remote/common/mock_Enumerator.go b/enumeration/remote/common/mock_Enumerator.go new file mode 100644 index 00000000..98ef337e --- /dev/null +++ b/enumeration/remote/common/mock_Enumerator.go @@ -0,0 +1,50 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package common + +import ( + resource "github.com/snyk/driftctl/enumeration/resource" + mock "github.com/stretchr/testify/mock" +) + +// MockEnumerator is an autogenerated mock type for the Enumerator type +type MockEnumerator struct { + mock.Mock +} + +// Enumerate provides a mock function with given fields: +func (_m *MockEnumerator) Enumerate() ([]*resource.Resource, error) { + ret := _m.Called() + + var r0 []*resource.Resource + if rf, ok := ret.Get(0).(func() []*resource.Resource); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*resource.Resource) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SupportedType provides a mock function with given fields: +func (_m *MockEnumerator) SupportedType() resource.ResourceType { + ret := _m.Called() + + var r0 resource.ResourceType + if rf, ok := ret.Get(0).(func() resource.ResourceType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(resource.ResourceType) + } + + return r0 +} diff --git a/enumeration/remote/common/providers.go b/enumeration/remote/common/providers.go new file mode 100644 index 00000000..8afd7ab1 --- /dev/null +++ b/enumeration/remote/common/providers.go @@ -0,0 +1,30 @@ +package common + +import ( + tf "github.com/snyk/driftctl/enumeration/terraform" + "github.com/snyk/driftctl/enumeration/terraform/lock" +) + +type RemoteParameter string + +const ( + RemoteAWSTerraform = "aws+tf" + RemoteGithubTerraform = "github+tf" + RemoteGoogleTerraform = "gcp+tf" + RemoteAzureTerraform = "azure+tf" +) + +var remoteParameterMapping = map[RemoteParameter]string{ + RemoteAWSTerraform: tf.AWS, + RemoteGithubTerraform: tf.GITHUB, + RemoteGoogleTerraform: tf.GOOGLE, + RemoteAzureTerraform: tf.AZURE, +} + +func (p RemoteParameter) GetProviderAddress() *lock.ProviderAddress { + return &lock.ProviderAddress{ + Hostname: "registry.terraform.io", + Namespace: "hashicorp", + Type: remoteParameterMapping[p], + } +} diff --git a/pkg/remote/error/errors.go b/enumeration/remote/error/errors.go similarity index 100% rename from pkg/remote/error/errors.go rename to enumeration/remote/error/errors.go diff --git a/enumeration/remote/github/github_branch_protection_enumerator.go b/enumeration/remote/github/github_branch_protection_enumerator.go new file mode 100644 index 00000000..eddc4e75 --- /dev/null +++ b/enumeration/remote/github/github_branch_protection_enumerator.go @@ -0,0 +1,45 @@ +package github + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) + +type GithubBranchProtectionEnumerator struct { + repository GithubRepository + factory resource.ResourceFactory +} + +func NewGithubBranchProtectionEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubBranchProtectionEnumerator { + return &GithubBranchProtectionEnumerator{ + repository: repo, + factory: factory, + } +} + +func (g *GithubBranchProtectionEnumerator) SupportedType() resource.ResourceType { + return github.GithubBranchProtectionResourceType +} + +func (g *GithubBranchProtectionEnumerator) Enumerate() ([]*resource.Resource, error) { + ids, err := g.repository.ListBranchProtection() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(ids)) + + for _, id := range ids { + results = append( + results, + g.factory.CreateAbstractResource( + string(g.SupportedType()), + id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/github/github_membership_enumerator.go b/enumeration/remote/github/github_membership_enumerator.go new file mode 100644 index 00000000..4322e39d --- /dev/null +++ b/enumeration/remote/github/github_membership_enumerator.go @@ -0,0 +1,45 @@ +package github + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) + +type GithubMembershipEnumerator struct { + Membership GithubRepository + factory resource.ResourceFactory +} + +func NewGithubMembershipEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubMembershipEnumerator { + return &GithubMembershipEnumerator{ + Membership: repo, + factory: factory, + } +} + +func (g *GithubMembershipEnumerator) SupportedType() resource.ResourceType { + return github.GithubMembershipResourceType +} + +func (g *GithubMembershipEnumerator) Enumerate() ([]*resource.Resource, error) { + ids, err := g.Membership.ListMembership() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(ids)) + + for _, id := range ids { + results = append( + results, + g.factory.CreateAbstractResource( + string(g.SupportedType()), + id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/github/github_repository_enumerator.go b/enumeration/remote/github/github_repository_enumerator.go new file mode 100644 index 00000000..e5150ffa --- /dev/null +++ b/enumeration/remote/github/github_repository_enumerator.go @@ -0,0 +1,45 @@ +package github + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) + +type GithubRepositoryEnumerator struct { + repository GithubRepository + factory resource.ResourceFactory +} + +func NewGithubRepositoryEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubRepositoryEnumerator { + return &GithubRepositoryEnumerator{ + repository: repo, + factory: factory, + } +} + +func (g *GithubRepositoryEnumerator) SupportedType() resource.ResourceType { + return github.GithubRepositoryResourceType +} + +func (g *GithubRepositoryEnumerator) Enumerate() ([]*resource.Resource, error) { + ids, err := g.repository.ListRepositories() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(ids)) + + for _, id := range ids { + results = append( + results, + g.factory.CreateAbstractResource( + string(g.SupportedType()), + id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/github/github_team_enumerator.go b/enumeration/remote/github/github_team_enumerator.go new file mode 100644 index 00000000..c4221190 --- /dev/null +++ b/enumeration/remote/github/github_team_enumerator.go @@ -0,0 +1,47 @@ +package github + +import ( + "fmt" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) + +type GithubTeamEnumerator struct { + repository GithubRepository + factory resource.ResourceFactory +} + +func NewGithubTeamEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubTeamEnumerator { + return &GithubTeamEnumerator{ + repository: repo, + factory: factory, + } +} + +func (g *GithubTeamEnumerator) SupportedType() resource.ResourceType { + return github.GithubTeamResourceType +} + +func (g *GithubTeamEnumerator) Enumerate() ([]*resource.Resource, error) { + resourceList, err := g.repository.ListTeams() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resourceList)) + + for _, team := range resourceList { + results = append( + results, + g.factory.CreateAbstractResource( + string(g.SupportedType()), + fmt.Sprintf("%d", team.DatabaseId), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/github/github_team_membership_enumerator.go b/enumeration/remote/github/github_team_membership_enumerator.go new file mode 100644 index 00000000..ae610c9b --- /dev/null +++ b/enumeration/remote/github/github_team_membership_enumerator.go @@ -0,0 +1,45 @@ +package github + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) + +type GithubTeamMembershipEnumerator struct { + repository GithubRepository + factory resource.ResourceFactory +} + +func NewGithubTeamMembershipEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubTeamMembershipEnumerator { + return &GithubTeamMembershipEnumerator{ + repository: repo, + factory: factory, + } +} + +func (g *GithubTeamMembershipEnumerator) SupportedType() resource.ResourceType { + return github.GithubTeamMembershipResourceType +} + +func (g *GithubTeamMembershipEnumerator) Enumerate() ([]*resource.Resource, error) { + ids, err := g.repository.ListTeamMemberships() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(ids)) + + for _, id := range ids { + results = append( + results, + g.factory.CreateAbstractResource( + string(g.SupportedType()), + id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/github/init.go b/enumeration/remote/github/init.go new file mode 100644 index 00000000..64a97851 --- /dev/null +++ b/enumeration/remote/github/init.go @@ -0,0 +1,63 @@ +package github + +import ( + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" + "github.com/snyk/driftctl/enumeration/terraform" +) + +/** + * Initialize remote (configure credentials, launch tf providers and start gRPC clients) + * Required to use Scanner + */ + +func Init(version string, alerter *alerter.Alerter, + providerLibrary *terraform.ProviderLibrary, + remoteLibrary *common2.RemoteLibrary, + progress enumeration.ProgressCounter, + resourceSchemaRepository *resource.SchemaRepository, + factory resource.ResourceFactory, + configDir string) error { + + provider, err := NewGithubTerraformProvider(version, progress, configDir) + if err != nil { + return err + } + err = provider.Init() + if err != nil { + return err + } + + repositoryCache := cache.New(100) + + repository := NewGithubRepository(provider.GetConfig(), repositoryCache) + deserializer := resource.NewDeserializer(factory) + providerLibrary.AddProvider(terraform.GITHUB, provider) + + remoteLibrary.AddEnumerator(NewGithubTeamEnumerator(repository, factory)) + remoteLibrary.AddDetailsFetcher(github.GithubTeamResourceType, common2.NewGenericDetailsFetcher(github.GithubTeamResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGithubRepositoryEnumerator(repository, factory)) + remoteLibrary.AddDetailsFetcher(github.GithubRepositoryResourceType, common2.NewGenericDetailsFetcher(github.GithubRepositoryResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGithubMembershipEnumerator(repository, factory)) + remoteLibrary.AddDetailsFetcher(github.GithubMembershipResourceType, common2.NewGenericDetailsFetcher(github.GithubMembershipResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGithubTeamMembershipEnumerator(repository, factory)) + remoteLibrary.AddDetailsFetcher(github.GithubTeamMembershipResourceType, common2.NewGenericDetailsFetcher(github.GithubTeamMembershipResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGithubBranchProtectionEnumerator(repository, factory)) + remoteLibrary.AddDetailsFetcher(github.GithubBranchProtectionResourceType, common2.NewGenericDetailsFetcher(github.GithubBranchProtectionResourceType, provider, deserializer)) + + err = resourceSchemaRepository.Init(terraform.GITHUB, provider.Version(), provider.Schema()) + if err != nil { + return err + } + github.InitResourcesMetadata(resourceSchemaRepository) + + return nil +} diff --git a/pkg/remote/github/mock_GithubRepository.go b/enumeration/remote/github/mock_GithubRepository.go similarity index 100% rename from pkg/remote/github/mock_GithubRepository.go rename to enumeration/remote/github/mock_GithubRepository.go diff --git a/enumeration/remote/github/provider.go b/enumeration/remote/github/provider.go new file mode 100644 index 00000000..b2b595ad --- /dev/null +++ b/enumeration/remote/github/provider.go @@ -0,0 +1,76 @@ +package github + +import ( + "os" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/terraform" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" +) + +type GithubTerraformProvider struct { + *terraform.TerraformProvider + name string + version string +} + +type githubConfig struct { + Token string + Owner string `cty:"owner"` + Organization string +} + +func NewGithubTerraformProvider(version string, progress enumeration.ProgressCounter, configDir string) (*GithubTerraformProvider, error) { + if version == "" { + version = "4.4.0" + } + p := &GithubTerraformProvider{ + version: version, + name: "github", + } + installer, err := terraform2.NewProviderInstaller(terraform2.ProviderConfig{ + Key: p.name, + Version: version, + ConfigDir: configDir, + }) + if err != nil { + return nil, err + } + tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ + Name: p.name, + DefaultAlias: p.GetConfig().getDefaultOwner(), + GetProviderConfig: func(owner string) interface{} { + return githubConfig{ + Owner: p.GetConfig().getDefaultOwner(), + } + }, + }, progress) + if err != nil { + return nil, err + } + p.TerraformProvider = tfProvider + return p, err +} + +func (c githubConfig) getDefaultOwner() string { + if c.Organization != "" { + return c.Organization + } + return c.Owner +} + +func (p GithubTerraformProvider) GetConfig() githubConfig { + return githubConfig{ + Token: os.Getenv("GITHUB_TOKEN"), + Owner: os.Getenv("GITHUB_OWNER"), + Organization: os.Getenv("GITHUB_ORGANIZATION"), + } +} + +func (p *GithubTerraformProvider) Name() string { + return p.name +} + +func (p *GithubTerraformProvider) Version() string { + return p.version +} diff --git a/enumeration/remote/github/repository.go b/enumeration/remote/github/repository.go new file mode 100644 index 00000000..80cfaed0 --- /dev/null +++ b/enumeration/remote/github/repository.go @@ -0,0 +1,361 @@ +package github + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + + "github.com/shurcooL/githubv4" + "golang.org/x/oauth2" +) + +type GithubRepository interface { + ListRepositories() ([]string, error) + ListTeams() ([]Team, error) + ListMembership() ([]string, error) + ListTeamMemberships() ([]string, error) + ListBranchProtection() ([]string, error) +} + +type GithubGraphQLClient interface { + Query(ctx context.Context, q interface{}, variables map[string]interface{}) error +} + +type githubRepository struct { + client GithubGraphQLClient + ctx context.Context + config githubConfig + cache cache.Cache +} + +func NewGithubRepository(config githubConfig, c cache.Cache) *githubRepository { + ctx := context.Background() + ts := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: config.Token}, + ) + oauthClient := oauth2.NewClient(ctx, ts) + + repo := &githubRepository{ + client: githubv4.NewClient(oauthClient), + ctx: context.Background(), + config: config, + cache: c, + } + + return repo +} + +func (r *githubRepository) ListRepositories() ([]string, error) { + if v := r.cache.Get("githubListRepositories"); v != nil { + return v.([]string), nil + } + + if r.config.Organization != "" { + results, err := r.listRepoForOrg() + if err != nil { + return nil, err + } + r.cache.Put("githubListRepositories", results) + return results, nil + } + + results, err := r.listRepoForOwner() + if err != nil { + return nil, err + } + r.cache.Put("githubListRepositories", results) + return results, nil +} + +type pageInfo struct { + EndCursor githubv4.String + HasNextPage bool +} + +type listRepoForOrgQuery struct { + Organization struct { + Repositories struct { + Nodes []struct { + Name string + } + PageInfo pageInfo + } `graphql:"repositories(first: 100, after: $cursor)"` + } `graphql:"organization(login: $org)"` +} + +func (r *githubRepository) listRepoForOrg() ([]string, error) { + query := listRepoForOrgQuery{} + variables := map[string]interface{}{ + "org": (githubv4.String)(r.config.Organization), + "cursor": (*githubv4.String)(nil), + } + var results []string + for { + err := r.client.Query(r.ctx, &query, variables) + if err != nil { + return nil, err + } + for _, repo := range query.Organization.Repositories.Nodes { + results = append(results, repo.Name) + } + if !query.Organization.Repositories.PageInfo.HasNextPage { + break + } + variables["cursor"] = githubv4.NewString(query.Organization.Repositories.PageInfo.EndCursor) + } + return results, nil +} + +type listRepoForOwnerQuery struct { + Viewer struct { + Repositories struct { + Nodes []struct { + Name string + } + PageInfo struct { + EndCursor githubv4.String + HasNextPage bool + } + } `graphql:"repositories(first: 100, after: $cursor)"` + } +} + +func (r githubRepository) listRepoForOwner() ([]string, error) { + query := listRepoForOwnerQuery{} + variables := map[string]interface{}{ + "cursor": (*githubv4.String)(nil), + } + var results []string + for { + err := r.client.Query(r.ctx, &query, variables) + if err != nil { + return nil, err + } + for _, repo := range query.Viewer.Repositories.Nodes { + results = append(results, repo.Name) + } + if !query.Viewer.Repositories.PageInfo.HasNextPage { + break + } + variables["cursor"] = githubv4.NewString(query.Viewer.Repositories.PageInfo.EndCursor) + } + return results, nil +} + +type listTeamsQuery struct { + Organization struct { + Teams struct { + Nodes []struct { + DatabaseId int + Slug string + } + PageInfo struct { + EndCursor githubv4.String + HasNextPage bool + } + } `graphql:"teams(first: 100, after: $cursor)"` + } `graphql:"organization(login: $login)"` +} + +type Team struct { + DatabaseId int + Slug string +} + +func (r githubRepository) ListTeams() ([]Team, error) { + if v := r.cache.Get("githubListTeams"); v != nil { + return v.([]Team), nil + } + + query := listTeamsQuery{} + results := make([]Team, 0) + if r.config.Organization == "" { + r.cache.Put("githubListTeams", results) + return results, nil + } + variables := map[string]interface{}{ + "cursor": (*githubv4.String)(nil), + "login": (githubv4.String)(r.config.Organization), + } + for { + err := r.client.Query(r.ctx, &query, variables) + if err != nil { + return nil, err + } + for _, team := range query.Organization.Teams.Nodes { + results = append(results, Team{ + DatabaseId: team.DatabaseId, + Slug: team.Slug, + }) + } + if !query.Organization.Teams.PageInfo.HasNextPage { + break + } + variables["cursor"] = githubv4.NewString(query.Organization.Teams.PageInfo.EndCursor) + } + + r.cache.Put("githubListTeams", results) + return results, nil +} + +type listMembership struct { + Organization struct { + MembersWithRole struct { + Nodes []struct { + Login string + } + PageInfo struct { + EndCursor githubv4.String + HasNextPage bool + } + } `graphql:"membersWithRole(first: 100, after: $cursor)"` + } `graphql:"organization(login: $login)"` +} + +func (r *githubRepository) ListMembership() ([]string, error) { + if v := r.cache.Get("githubListMembership"); v != nil { + return v.([]string), nil + } + + query := listMembership{} + results := make([]string, 0) + if r.config.Organization == "" { + r.cache.Put("githubListMembership", results) + return results, nil + } + variables := map[string]interface{}{ + "cursor": (*githubv4.String)(nil), + "login": (githubv4.String)(r.config.Organization), + } + for { + err := r.client.Query(r.ctx, &query, variables) + if err != nil { + return nil, err + } + for _, membership := range query.Organization.MembersWithRole.Nodes { + results = append(results, fmt.Sprintf("%s:%s", r.config.Organization, membership.Login)) + } + if !query.Organization.MembersWithRole.PageInfo.HasNextPage { + break + } + variables["cursor"] = githubv4.NewString(query.Organization.MembersWithRole.PageInfo.EndCursor) + } + + r.cache.Put("githubListMembership", results) + return results, nil +} + +type listTeamMembershipsQuery struct { + Organization struct { + Team struct { + Members struct { + Nodes []struct { + Login string + } + PageInfo struct { + EndCursor githubv4.String + HasNextPage bool + } + } `graphql:"members(first: 100, after: $cursor)"` + } `graphql:"team(slug: $slug)"` + } `graphql:"organization(login: $login)"` +} + +func (r githubRepository) ListTeamMemberships() ([]string, error) { + if v := r.cache.Get("githubListTeamMemberships"); v != nil { + return v.([]string), nil + } + + teamList, err := r.ListTeams() + if err != nil { + return nil, err + } + + query := listTeamMembershipsQuery{} + results := make([]string, 0) + if r.config.Organization == "" { + r.cache.Put("githubListTeamMemberships", results) + return results, nil + } + variables := map[string]interface{}{ + "login": (githubv4.String)(r.config.Organization), + } + + for _, team := range teamList { + variables["slug"] = (githubv4.String)(team.Slug) + variables["cursor"] = (*githubv4.String)(nil) + for { + err := r.client.Query(r.ctx, &query, variables) + if err != nil { + return nil, err + } + for _, membership := range query.Organization.Team.Members.Nodes { + results = append(results, fmt.Sprintf("%d:%s", team.DatabaseId, membership.Login)) + } + if !query.Organization.Team.Members.PageInfo.HasNextPage { + break + } + variables["cursor"] = query.Organization.Team.Members.PageInfo.EndCursor + } + } + + r.cache.Put("githubListTeamMemberships", results) + return results, nil +} + +type listBranchProtectionQuery struct { + Repository struct { + BranchProtectionRules struct { + Nodes []struct { + Id string + } + PageInfo struct { + EndCursor githubv4.String + HasNextPage bool + } + } `graphql:"branchProtectionRules(first: 1, after: $cursor)"` + } `graphql:"repository(owner: $owner, name: $name)"` +} + +func (r *githubRepository) ListBranchProtection() ([]string, error) { + if v := r.cache.Get("githubListBranchProtection"); v != nil { + return v.([]string), nil + } + + repoList, err := r.ListRepositories() + if err != nil { + return nil, err + } + + results := make([]string, 0) + query := listBranchProtectionQuery{} + variables := map[string]interface{}{ + "cursor": (*githubv4.String)(nil), + "owner": (githubv4.String)(r.config.getDefaultOwner()), + "name": (githubv4.String)(""), + } + + for _, repo := range repoList { + variables["name"] = (githubv4.String)(repo) + variables["cursor"] = (*githubv4.String)(nil) + for { + err := r.client.Query(r.ctx, &query, variables) + if err != nil { + return nil, err + } + for _, protection := range query.Repository.BranchProtectionRules.Nodes { + results = append(results, protection.Id) + } + + variables["cursor"] = query.Repository.BranchProtectionRules.PageInfo.EndCursor + + if !query.Repository.BranchProtectionRules.PageInfo.HasNextPage { + break + } + } + + } + + r.cache.Put("githubListBranchProtection", results) + return results, nil +} diff --git a/enumeration/remote/github/repository_test.go b/enumeration/remote/github/repository_test.go new file mode 100644 index 00000000..215aafa2 --- /dev/null +++ b/enumeration/remote/github/repository_test.go @@ -0,0 +1,920 @@ +package github + +import ( + "context" + "github.com/snyk/driftctl/enumeration/remote/cache" + "testing" + + "github.com/pkg/errors" + "github.com/shurcooL/githubv4" + "github.com/snyk/driftctl/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestListRepositoriesForUser_WithError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{}, + cache: cache.New(1), + } + + _, err := r.ListRepositories() + assert.Equal(t, expectedError, err) +} + +func TestListRepositoriesForUser(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOwnerQuery) + if !ok { + return false + } + q.Viewer.Repositories.Nodes = []struct{ Name string }{ + { + Name: "repo1", + }, + { + Name: "repo2", + }, + } + q.Viewer.Repositories.PageInfo = pageInfo{ + EndCursor: "next", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOwnerQuery) + if !ok { + return false + } + q.Viewer.Repositories.Nodes = []struct{ Name string }{ + { + Name: "repo3", + }, + { + Name: "repo4", + }, + } + q.Viewer.Repositories.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "cursor": githubv4.NewString("next"), + }).Return(nil).Once() + + store := cache.New(1) + r := githubRepository{ + client: &mockedClient, + ctx: context.TODO(), + config: githubConfig{}, + cache: store, + } + + repos, err := r.ListRepositories() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, []string{ + "repo1", + "repo2", + "repo3", + "repo4", + }, repos) + + // Check that results were cached + cachedData, err := r.ListRepositories() + assert.NoError(t, err) + assert.Equal(t, repos, cachedData) + assert.IsType(t, []string{}, store.Get("githubListRepositories")) +} + +func TestListRepositoriesForOrganization_WithError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "testorg", + }, + cache: cache.New(1), + } + + _, err := r.ListRepositories() + assert.Equal(t, expectedError, err) +} + +func TestListRepositoriesForOrganization(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOrgQuery) + if !ok { + return false + } + q.Organization.Repositories.Nodes = []struct { + Name string + }{ + { + Name: "repo1", + }, + { + Name: "repo2", + }, + } + q.Organization.Repositories.PageInfo = pageInfo{ + EndCursor: "next", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "org": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOrgQuery) + if !ok { + return false + } + q.Organization.Repositories.Nodes = []struct { + Name string + }{ + { + Name: "repo3", + }, + { + Name: "repo4", + }, + } + q.Organization.Repositories.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "org": (githubv4.String)("testorg"), + "cursor": githubv4.NewString("next"), + }).Return(nil).Once() + + store := cache.New(1) + r := githubRepository{ + client: &mockedClient, + ctx: context.TODO(), + config: githubConfig{ + Organization: "testorg", + }, + cache: store, + } + + repos, err := r.ListRepositories() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, []string{ + "repo1", + "repo2", + "repo3", + "repo4", + }, repos) + + // Check that results were cached + cachedData, err := r.ListRepositories() + assert.NoError(t, err) + assert.Equal(t, repos, cachedData) + assert.IsType(t, []string{}, store.Get("githubListRepositories")) +} + +func TestListTeams_WithError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "testorg", + }, + cache: cache.New(1), + } + + _, err := r.ListTeams() + assert.Equal(t, expectedError, err) +} + +func TestListTeams_WithoutOrganization(t *testing.T) { + r := githubRepository{cache: cache.New(1)} + + teams, err := r.ListTeams() + assert.Nil(t, err) + assert.Equal(t, []Team{}, teams) +} + +func TestListTeams(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamsQuery) + if !ok { + return false + } + q.Organization.Teams.Nodes = []struct { + DatabaseId int + Slug string + }{ + { + DatabaseId: 1, + Slug: "1", + }, + { + DatabaseId: 2, + Slug: "2", + }, + } + q.Organization.Teams.PageInfo = pageInfo{ + EndCursor: "next", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamsQuery) + if !ok { + return false + } + q.Organization.Teams.Nodes = []struct { + DatabaseId int + Slug string + }{ + { + DatabaseId: 3, + Slug: "3", + }, + { + DatabaseId: 4, + Slug: "4", + }, + } + q.Organization.Teams.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": githubv4.NewString("next"), + }).Return(nil).Once() + + store := cache.New(1) + r := githubRepository{ + client: &mockedClient, + ctx: context.TODO(), + config: githubConfig{ + Organization: "testorg", + }, + cache: store, + } + + teams, err := r.ListTeams() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, []Team{ + {1, "1"}, + {2, "2"}, + {3, "3"}, + {4, "4"}, + }, teams) + + // Check that results were cached + cachedData, err := r.ListTeams() + assert.NoError(t, err) + assert.Equal(t, teams, cachedData) + assert.IsType(t, []Team{}, store.Get("githubListTeams")) +} + +func TestListTeamMemberships_WithTeamListingError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "testorg", + }, + cache: cache.New(1), + } + + _, err := r.ListTeamMemberships() + assert.Equal(t, expectedError, err) +} + +func TestListTeamMemberships_WithError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamsQuery) + if !ok { + return false + } + q.Organization.Teams.Nodes = []struct { + DatabaseId int + Slug string + }{ + { + DatabaseId: 1, + Slug: "foo", + }, + } + q.Organization.Teams.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + }).Return(nil) + + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "testorg", + }, + cache: cache.New(1), + } + + _, err := r.ListTeamMemberships() + assert.Equal(t, expectedError, err) +} + +func TestListTeamMemberships_WithoutOrganization(t *testing.T) { + r := githubRepository{cache: cache.New(1)} + + teams, err := r.ListTeamMemberships() + assert.Nil(t, err) + assert.Equal(t, []string{}, teams) +} + +func TestListTeamMemberships(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamsQuery) + if !ok { + return false + } + q.Organization.Teams.Nodes = []struct { + DatabaseId int + Slug string + }{ + { + DatabaseId: 1, + Slug: "foo", + }, + { + DatabaseId: 2, + Slug: "bar", + }, + } + q.Organization.Teams.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamMembershipsQuery) + if !ok { + return false + } + q.Organization.Team.Members.Nodes = []struct { + Login string + }{ + { + Login: "user-1", + }, + { + Login: "user-2", + }, + } + q.Organization.Team.Members.PageInfo = pageInfo{ + EndCursor: "next", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + "slug": (githubv4.String)("foo"), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamMembershipsQuery) + if !ok { + return false + } + q.Organization.Team.Members.Nodes = []struct { + Login string + }{ + { + Login: "user-3", + }, + { + Login: "user-4", + }, + } + q.Organization.Team.Members.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (githubv4.String)("next"), + "slug": (githubv4.String)("foo"), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listTeamMembershipsQuery) + if !ok { + return false + } + q.Organization.Team.Members.Nodes = []struct { + Login string + }{ + { + Login: "user-5", + }, + { + Login: "user-6", + }, + } + q.Organization.Team.Members.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + "slug": (githubv4.String)("bar"), + }).Return(nil).Once() + + store := cache.New(1) + r := githubRepository{ + client: &mockedClient, + ctx: context.TODO(), + config: githubConfig{ + Organization: "testorg", + }, + cache: store, + } + + memberships, err := r.ListTeamMemberships() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, []string{ + "1:user-1", + "1:user-2", + "1:user-3", + "1:user-4", + "2:user-5", + "2:user-6", + }, memberships) + + // Check that results were cached + cachedData, err := r.ListTeamMemberships() + assert.NoError(t, err) + assert.Equal(t, memberships, cachedData) + assert.IsType(t, []string{}, store.Get("githubListTeamMemberships")) +} + +func TestListMembership_WithError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "testorg", + }, + cache: cache.New(1), + } + + _, err := r.ListMembership() + assert.Equal(t, expectedError, err) +} + +func TestListMembership_WithoutOrganization(t *testing.T) { + r := githubRepository{cache: cache.New(1)} + + teams, err := r.ListMembership() + assert.Nil(t, err) + assert.Equal(t, []string{}, teams) +} + +func TestListMembership(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listMembership) + if !ok { + return false + } + q.Organization.MembersWithRole.Nodes = []struct { + Login string + }{ + { + Login: "user-admin", + }, + { + Login: "user-non-admin-1", + }, + } + q.Organization.MembersWithRole.PageInfo = pageInfo{ + EndCursor: "next", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listMembership) + if !ok { + return false + } + q.Organization.MembersWithRole.Nodes = []struct { + Login string + }{ + { + Login: "user-non-admin-2", + }, + { + Login: "user-non-admin-3", + }, + } + q.Organization.MembersWithRole.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "login": (githubv4.String)("testorg"), + "cursor": githubv4.NewString("next"), + }).Return(nil).Once() + + store := cache.New(1) + r := githubRepository{ + client: &mockedClient, + ctx: context.TODO(), + config: githubConfig{ + Organization: "testorg", + }, + cache: store, + } + + teams, err := r.ListMembership() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, []string{ + "testorg:user-admin", + "testorg:user-non-admin-1", + "testorg:user-non-admin-2", + "testorg:user-non-admin-3", + }, teams) + + // Check that results were cached + cachedData, err := r.ListMembership() + assert.NoError(t, err) + assert.Equal(t, teams, cachedData) + assert.IsType(t, []string{}, store.Get("githubListMembership")) + +} + +func TestListBranchProtection_WithRepoListingError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOrgQuery) + if !ok { + return false + } + q.Organization.Repositories.Nodes = []struct { + Name string + }{ + { + Name: "repo1", + }, + { + Name: "repo2", + }, + } + q.Organization.Repositories.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "org": (githubv4.String)("my-organization"), + "cursor": (*githubv4.String)(nil), + }).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "my-organization", + }, + cache: cache.New(1), + } + + _, err := r.ListBranchProtection() + assert.Equal(t, expectedError, err) +} + +func TestListBranchProtection_WithError(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + expectedError := errors.New("test error from graphql") + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOrgQuery) + if !ok { + return false + } + q.Organization.Repositories.Nodes = []struct { + Name string + }{ + { + Name: "repo1", + }, + { + Name: "repo2", + }, + } + q.Organization.Repositories.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "org": (githubv4.String)("testorg"), + "cursor": (*githubv4.String)(nil), + }).Return(nil) + + mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) + + r := githubRepository{ + client: &mockedClient, + config: githubConfig{ + Organization: "testorg", + }, + cache: cache.New(1), + } + + _, err := r.ListBranchProtection() + assert.Equal(t, expectedError, err) +} + +func TestListBranchProtection(t *testing.T) { + mockedClient := mocks.GithubGraphQLClient{} + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listRepoForOrgQuery) + if !ok { + return false + } + q.Organization.Repositories.Nodes = []struct { + Name string + }{ + { + Name: "repo1", + }, + { + Name: "repo2", + }, + } + q.Organization.Repositories.PageInfo = pageInfo{ + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "org": (githubv4.String)("my-organization"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listBranchProtectionQuery) + if !ok { + return false + } + q.Repository.BranchProtectionRules.Nodes = []struct { + Id string + }{ + { + Id: "id1", + }, + { + Id: "id2", + }, + } + q.Repository.BranchProtectionRules.PageInfo = pageInfo{ + EndCursor: "nextPage", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "owner": (githubv4.String)("my-organization"), + "name": (githubv4.String)("repo1"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listBranchProtectionQuery) + if !ok { + return false + } + q.Repository.BranchProtectionRules.Nodes = []struct { + Id string + }{ + { + Id: "id3", + }, + { + Id: "id4", + }, + } + q.Repository.BranchProtectionRules.PageInfo = pageInfo{ + EndCursor: "nextPage", + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "owner": (githubv4.String)("my-organization"), + "name": (githubv4.String)("repo1"), + "cursor": (githubv4.String)("nextPage"), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listBranchProtectionQuery) + if !ok { + return false + } + q.Repository.BranchProtectionRules.Nodes = []struct { + Id string + }{ + { + Id: "id5", + }, + { + Id: "id6", + }, + } + q.Repository.BranchProtectionRules.PageInfo = pageInfo{ + EndCursor: "nextPage", + HasNextPage: true, + } + return true + }), + map[string]interface{}{ + "owner": (githubv4.String)("my-organization"), + "name": (githubv4.String)("repo2"), + "cursor": (*githubv4.String)(nil), + }).Return(nil).Once() + + mockedClient.On("Query", + mock.Anything, + mock.MatchedBy(func(query interface{}) bool { + q, ok := query.(*listBranchProtectionQuery) + if !ok { + return false + } + q.Repository.BranchProtectionRules.Nodes = []struct { + Id string + }{ + { + Id: "id7", + }, + { + Id: "id8", + }, + } + q.Repository.BranchProtectionRules.PageInfo = pageInfo{ + EndCursor: "nextPage", + HasNextPage: false, + } + return true + }), + map[string]interface{}{ + "owner": (githubv4.String)("my-organization"), + "name": (githubv4.String)("repo2"), + "cursor": (githubv4.String)("nextPage"), + }).Return(nil).Once() + + store := cache.New(1) + r := githubRepository{ + client: &mockedClient, + ctx: context.TODO(), + config: githubConfig{ + Organization: "my-organization", + }, + cache: store, + } + + teams, err := r.ListBranchProtection() + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, []string{ + "id1", + "id2", + "id3", + "id4", + "id5", + "id6", + "id7", + "id8", + }, teams) + + // Check that results were cached + cachedData, err := r.ListBranchProtection() + assert.NoError(t, err) + assert.Equal(t, teams, cachedData) + assert.IsType(t, []string{}, store.Get("githubListBranchProtection")) +} diff --git a/enumeration/remote/github_branch_protection_scanner_test.go b/enumeration/remote/github_branch_protection_scanner_test.go new file mode 100644 index 00000000..e2801285 --- /dev/null +++ b/enumeration/remote/github_branch_protection_scanner_test.go @@ -0,0 +1,125 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + github2 "github.com/snyk/driftctl/enumeration/remote/github" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/pkg/errors" + githubres "github.com/snyk/driftctl/enumeration/resource/github" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + tftest "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanGithubBranchProtection(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*github2.MockGithubRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no branch protection", + dirName: "github_branch_protection_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListBranchProtection").Return([]string{}, nil) + }, + err: nil, + }, + { + test: "Multiple branch protections", + dirName: "github_branch_protection_multiples", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListBranchProtection").Return([]string{ + "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzI=", // "repo0:main" + "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzg=", // "repo0:toto" + "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzQ=", // "repo1:main" + "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0ODA=", // "repo1:toto" + "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzE=", // "repo2:main" + "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzc=", // "repo2:toto" + }, nil) + }, + err: nil, + }, + { + test: "cannot list branch protections", + dirName: "github_branch_protection_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListBranchProtection").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) + + alerter.On("SendAlert", githubres.GithubBranchProtectionResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubBranchProtectionResourceType, githubres.GithubBranchProtectionResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") + githubres.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + mockedRepo := github2.MockGithubRepository{} + c.mocks(&mockedRepo, alerter) + + var repo github2.GithubRepository = &mockedRepo + + realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") + if err != nil { + t.Fatal(err) + } + provider := tftest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = github2.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(github2.NewGithubBranchProtectionEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(githubres.GithubBranchProtectionResourceType, common2.NewGenericDetailsFetcher(githubres.GithubBranchProtectionResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, githubres.GithubBranchProtectionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + mockedRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/github_membership_scanner_test.go b/enumeration/remote/github_membership_scanner_test.go new file mode 100644 index 00000000..acb9a687 --- /dev/null +++ b/enumeration/remote/github_membership_scanner_test.go @@ -0,0 +1,121 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + github2 "github.com/snyk/driftctl/enumeration/remote/github" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + githubres "github.com/snyk/driftctl/enumeration/resource/github" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + tftest "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanGithubMembership(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*github2.MockGithubRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no members", + dirName: "github_membership_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListMembership").Return([]string{}, nil) + }, + err: nil, + }, + { + test: "Multiple membership with admin and member roles", + dirName: "github_membership_multiple", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListMembership").Return([]string{ + "driftctl-test:driftctl-acceptance-tester", + "driftctl-test:eliecharra", + }, nil) + }, + err: nil, + }, + { + test: "cannot list membership", + dirName: "github_membership_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListMembership").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) + + alerter.On("SendAlert", githubres.GithubMembershipResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubMembershipResourceType, githubres.GithubMembershipResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") + githubres.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + mockedRepo := github2.MockGithubRepository{} + c.mocks(&mockedRepo, alerter) + + var repo github2.GithubRepository = &mockedRepo + + realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") + if err != nil { + t.Fatal(err) + } + provider := tftest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = github2.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(github2.NewGithubMembershipEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(githubres.GithubMembershipResourceType, common2.NewGenericDetailsFetcher(githubres.GithubMembershipResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, githubres.GithubMembershipResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + mockedRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/github_repository_scanner_test.go b/enumeration/remote/github_repository_scanner_test.go new file mode 100644 index 00000000..14bac7c4 --- /dev/null +++ b/enumeration/remote/github_repository_scanner_test.go @@ -0,0 +1,121 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + github2 "github.com/snyk/driftctl/enumeration/remote/github" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + githubres "github.com/snyk/driftctl/enumeration/resource/github" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + tftest "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanGithubRepository(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*github2.MockGithubRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no github repos", + dirName: "github_repository_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListRepositories").Return([]string{}, nil) + }, + err: nil, + }, + { + test: "Multiple github repos Table", + dirName: "github_repository_multiple", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListRepositories").Return([]string{ + "driftctl", + "driftctl-demos", + }, nil) + }, + err: nil, + }, + { + test: "cannot list repositories", + dirName: "github_repository_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListRepositories").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) + + alerter.On("SendAlert", githubres.GithubRepositoryResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubRepositoryResourceType, githubres.GithubRepositoryResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") + githubres.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + mockedRepo := github2.MockGithubRepository{} + c.mocks(&mockedRepo, alerter) + + var repo github2.GithubRepository = &mockedRepo + + realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") + if err != nil { + t.Fatal(err) + } + provider := tftest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = github2.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(github2.NewGithubRepositoryEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(githubres.GithubRepositoryResourceType, common2.NewGenericDetailsFetcher(githubres.GithubRepositoryResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, githubres.GithubRepositoryResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + mockedRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/github_team_membership_scanner_test.go b/enumeration/remote/github_team_membership_scanner_test.go new file mode 100644 index 00000000..5bb078e7 --- /dev/null +++ b/enumeration/remote/github_team_membership_scanner_test.go @@ -0,0 +1,121 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + github2 "github.com/snyk/driftctl/enumeration/remote/github" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + githubres "github.com/snyk/driftctl/enumeration/resource/github" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + tftest "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanGithubTeamMembership(t *testing.T) { + + cases := []struct { + test string + dirName string + mocks func(*github2.MockGithubRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no github team memberships", + dirName: "github_team_membership_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListTeamMemberships").Return([]string{}, nil) + }, + err: nil, + }, + { + test: "multiple github team memberships", + dirName: "github_team_membership_multiple", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListTeamMemberships").Return([]string{ + "4570529:driftctl-acceptance-tester", + "4570529:wbeuil", + }, nil) + }, + err: nil, + }, + { + test: "cannot list team membership", + dirName: "github_team_membership_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListTeamMemberships").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) + + alerter.On("SendAlert", githubres.GithubTeamMembershipResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubTeamMembershipResourceType, githubres.GithubTeamMembershipResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") + githubres.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + mockedRepo := github2.MockGithubRepository{} + c.mocks(&mockedRepo, alerter) + + var repo github2.GithubRepository = &mockedRepo + + realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") + if err != nil { + t.Fatal(err) + } + provider := tftest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = github2.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(github2.NewGithubTeamMembershipEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(githubres.GithubTeamMembershipResourceType, common2.NewGenericDetailsFetcher(githubres.GithubTeamMembershipResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, githubres.GithubTeamMembershipResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + mockedRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} diff --git a/enumeration/remote/github_team_scanner_test.go b/enumeration/remote/github_team_scanner_test.go new file mode 100644 index 00000000..f99f3c2d --- /dev/null +++ b/enumeration/remote/github_team_scanner_test.go @@ -0,0 +1,122 @@ +package remote + +import ( + "errors" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + github2 "github.com/snyk/driftctl/enumeration/remote/github" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + githubres "github.com/snyk/driftctl/enumeration/resource/github" + "github.com/snyk/driftctl/mocks" + + testresource "github.com/snyk/driftctl/test/resource" + tftest "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/mock" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + "github.com/stretchr/testify/assert" +) + +func TestScanGithubTeam(t *testing.T) { + + tests := []struct { + test string + dirName string + mocks func(*github2.MockGithubRepository, *mocks.AlerterInterface) + err error + }{ + { + test: "no github teams", + dirName: "github_teams_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListTeams").Return([]github2.Team{}, nil) + }, + err: nil, + }, + { + test: "Multiple github teams with parent", + dirName: "github_teams_multiple", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListTeams").Return([]github2.Team{ + {DatabaseId: 4556811}, // github_team.team1 + {DatabaseId: 4556812}, // github_team.team2 + {DatabaseId: 4556814}, // github_team.with_parent + }, nil) + }, + err: nil, + }, + { + test: "cannot list teams", + dirName: "github_teams_empty", + mocks: func(client *github2.MockGithubRepository, alerter *mocks.AlerterInterface) { + client.On("ListTeams").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) + + alerter.On("SendAlert", githubres.GithubTeamResourceType, alerts.NewRemoteAccessDeniedAlert(common2.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubTeamResourceType, githubres.GithubTeamResourceType), alerts.EnumerationPhase)).Return() + }, + err: nil, + }, + } + + schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") + githubres.InitResourcesMetadata(schemaRepository) + factory := terraform2.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range tests { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + + providerLibrary := terraform2.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + mockedRepo := github2.MockGithubRepository{} + c.mocks(&mockedRepo, alerter) + + var repo github2.GithubRepository = &mockedRepo + + realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") + if err != nil { + t.Fatal(err) + } + provider := tftest.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + if shouldUpdate { + err := realProvider.Init() + if err != nil { + t.Fatal(err) + } + provider.ShouldUpdate() + repo = github2.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) + } + + remoteLibrary.AddEnumerator(github2.NewGithubTeamEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(githubres.GithubTeamResourceType, common2.NewGenericDetailsFetcher(githubres.GithubTeamResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.err) + if err != nil { + return + } + test.TestAgainstGoldenFile(got, githubres.GithubTeamResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) + mockedRepo.AssertExpectations(tt) + alerter.AssertExpectations(tt) + }) + } +} diff --git a/pkg/remote/google/config/config.go b/enumeration/remote/google/config/config.go similarity index 100% rename from pkg/remote/google/config/config.go rename to enumeration/remote/google/config/config.go diff --git a/enumeration/remote/google/google_bigquery_dataset_enumerator.go b/enumeration/remote/google/google_bigquery_dataset_enumerator.go new file mode 100644 index 00000000..1c799aa9 --- /dev/null +++ b/enumeration/remote/google/google_bigquery_dataset_enumerator.go @@ -0,0 +1,49 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleBigqueryDatasetEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleBigqueryDatasetEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigqueryDatasetEnumerator { + return &GoogleBigqueryDatasetEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleBigqueryDatasetEnumerator) SupportedType() resource.ResourceType { + return google.GoogleBigqueryDatasetResourceType +} + +func (e *GoogleBigqueryDatasetEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllDatasets() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "friendly_name": res.DisplayName, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_bigquery_table_enumerator.go b/enumeration/remote/google/google_bigquery_table_enumerator.go new file mode 100644 index 00000000..6b4dbea1 --- /dev/null +++ b/enumeration/remote/google/google_bigquery_table_enumerator.go @@ -0,0 +1,49 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleBigqueryTableEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleBigqueryTableEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigqueryTableEnumerator { + return &GoogleBigqueryTableEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleBigqueryTableEnumerator) SupportedType() resource.ResourceType { + return google.GoogleBigqueryTableResourceType +} + +func (e *GoogleBigqueryTableEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllTables() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "friendly_name": res.DisplayName, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_bigtable_instance_enumerator.go b/enumeration/remote/google/google_bigtable_instance_enumerator.go new file mode 100644 index 00000000..ef4fed10 --- /dev/null +++ b/enumeration/remote/google/google_bigtable_instance_enumerator.go @@ -0,0 +1,53 @@ +package google + +import ( + "github.com/sirupsen/logrus" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleBigTableInstanceEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleBigTableInstanceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigTableInstanceEnumerator { + return &GoogleBigTableInstanceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleBigTableInstanceEnumerator) SupportedType() resource.ResourceType { + return google.GoogleBigTableInstanceResourceType +} + +func (e *GoogleBigTableInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllBigtableInstances() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + name, exist := res.GetResource().GetData().GetFields()["name"] + if !exist || name.GetStringValue() == "" { + logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + name.GetStringValue(), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_bigtable_table_enumerator.go b/enumeration/remote/google/google_bigtable_table_enumerator.go new file mode 100644 index 00000000..32dc677b --- /dev/null +++ b/enumeration/remote/google/google_bigtable_table_enumerator.go @@ -0,0 +1,53 @@ +package google + +import ( + "github.com/sirupsen/logrus" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleBigtableTableEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleBigtableTableEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigtableTableEnumerator { + return &GoogleBigtableTableEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleBigtableTableEnumerator) SupportedType() resource.ResourceType { + return google.GoogleBigtableTableResourceType +} + +func (e *GoogleBigtableTableEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllBigtableTables() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + name, exist := res.GetResource().GetData().GetFields()["name"] + if !exist || name.GetStringValue() == "" { + logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + name.GetStringValue(), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_cloudfunctions_function_enumerator.go b/enumeration/remote/google/google_cloudfunctions_function_enumerator.go new file mode 100644 index 00000000..914ac6b9 --- /dev/null +++ b/enumeration/remote/google/google_cloudfunctions_function_enumerator.go @@ -0,0 +1,53 @@ +package google + +import ( + "github.com/sirupsen/logrus" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleCloudFunctionsFunctionEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleCloudFunctionsFunctionEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleCloudFunctionsFunctionEnumerator { + return &GoogleCloudFunctionsFunctionEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleCloudFunctionsFunctionEnumerator) SupportedType() resource.ResourceType { + return google.GoogleCloudFunctionsFunctionResourceType +} + +func (e *GoogleCloudFunctionsFunctionEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllFunctions() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + name, exist := res.GetResource().GetData().GetFields()["name"] + if !exist || name.GetStringValue() == "" { + logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + name.GetStringValue(), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_cloudrun_service_enumerator.go b/enumeration/remote/google/google_cloudrun_service_enumerator.go new file mode 100644 index 00000000..2a8cdc0e --- /dev/null +++ b/enumeration/remote/google/google_cloudrun_service_enumerator.go @@ -0,0 +1,62 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "strings" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleCloudRunServiceEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleCloudRunServiceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleCloudRunServiceEnumerator { + return &GoogleCloudRunServiceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleCloudRunServiceEnumerator) SupportedType() resource.ResourceType { + return google.GoogleCloudRunServiceResourceType +} + +func (e *GoogleCloudRunServiceEnumerator) Enumerate() ([]*resource.Resource, error) { + subnets, err := e.repository.SearchAllCloudRunServices() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(subnets)) + + for _, res := range subnets { + splittedName := strings.Split(res.GetName(), "/") + if len(splittedName) != 9 { + logrus.WithField("name", res.GetName()).Error("Unable to decode project from resource name") + continue + } + project := splittedName[4] + id := strings.Join([]string{ + "locations", res.GetLocation(), + "namespaces", project, + "services", res.GetDisplayName(), + }, "/") + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{ + "name": res.GetDisplayName(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_address_enumerator.go b/enumeration/remote/google/google_compute_address_enumerator.go new file mode 100644 index 00000000..93a19382 --- /dev/null +++ b/enumeration/remote/google/google_compute_address_enumerator.go @@ -0,0 +1,58 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeAddressEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeAddressEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeAddressEnumerator { + return &GoogleComputeAddressEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeAddressEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeAddressResourceType +} + +func (e *GoogleComputeAddressEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllAddresses() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + // Global addresses are handled as a dedicated resource + if res.GetLocation() == "global" { + continue + } + address := "" + if addr, exist := res.GetAdditionalAttributes().GetFields()["address"]; exist { + address = addr.GetStringValue() + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetDisplayName(), + "address": address, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_disk_enumerator.go b/enumeration/remote/google/google_compute_disk_enumerator.go new file mode 100644 index 00000000..856eac25 --- /dev/null +++ b/enumeration/remote/google/google_compute_disk_enumerator.go @@ -0,0 +1,49 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeDiskEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeDiskEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeDiskEnumerator { + return &GoogleComputeDiskEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeDiskEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeDiskResourceType +} + +func (e *GoogleComputeDiskEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllDisks() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetDisplayName(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_firewall_enumerator.go b/enumeration/remote/google/google_compute_firewall_enumerator.go new file mode 100644 index 00000000..b0a75a7a --- /dev/null +++ b/enumeration/remote/google/google_compute_firewall_enumerator.go @@ -0,0 +1,59 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "strings" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeFirewallEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeFirewallEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeFirewallEnumerator { + return &GoogleComputeFirewallEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeFirewallEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeFirewallResourceType +} + +func (e *GoogleComputeFirewallEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllFirewalls() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + splittedName := strings.Split(res.GetName(), "/") + if len(splittedName) != 8 { + logrus.WithField("name", res.GetName()).Error("Unable to decode project from firewall name") + continue + } + project := splittedName[4] + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.DisplayName, + "project": project, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_forwarding_rule_enumerator.go b/enumeration/remote/google/google_compute_forwarding_rule_enumerator.go new file mode 100644 index 00000000..1da14161 --- /dev/null +++ b/enumeration/remote/google/google_compute_forwarding_rule_enumerator.go @@ -0,0 +1,45 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeForwardingRuleEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeForwardingRuleEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeForwardingRuleEnumerator { + return &GoogleComputeForwardingRuleEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeForwardingRuleEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeForwardingRuleResourceType +} + +func (e *GoogleComputeForwardingRuleEnumerator) Enumerate() ([]*resource.Resource, error) { + forwardingRules, err := e.repository.SearchAllForwardingRules() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(forwardingRules)) + for _, res := range forwardingRules { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_global_address_enumerator.go b/enumeration/remote/google/google_compute_global_address_enumerator.go new file mode 100644 index 00000000..6dc0647b --- /dev/null +++ b/enumeration/remote/google/google_compute_global_address_enumerator.go @@ -0,0 +1,60 @@ +package google + +import ( + "github.com/sirupsen/logrus" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeGlobalAddressEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeGlobalAddressEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeGlobalAddressEnumerator { + return &GoogleComputeGlobalAddressEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeGlobalAddressEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeGlobalAddressResourceType +} + +func (e *GoogleComputeGlobalAddressEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllGlobalAddresses() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + name, exist := res.GetResource().GetData().GetFields()["name"] + if !exist || name.GetStringValue() == "" { + logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") + continue + } + address := "" + if addr, exist := res.GetResource().GetData().GetFields()["address"]; exist { + address = addr.GetStringValue() + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": name.GetStringValue(), + "address": address, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_global_forwarding_rule_enumerator.go b/enumeration/remote/google/google_compute_global_forwarding_rule_enumerator.go new file mode 100644 index 00000000..54fedb28 --- /dev/null +++ b/enumeration/remote/google/google_compute_global_forwarding_rule_enumerator.go @@ -0,0 +1,46 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeGlobalForwardingRuleEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeGlobalForwardingRuleEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeGlobalForwardingRuleEnumerator { + return &GoogleComputeGlobalForwardingRuleEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeGlobalForwardingRuleEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeGlobalForwardingRuleResourceType +} + +func (e *GoogleComputeGlobalForwardingRuleEnumerator) Enumerate() ([]*resource.Resource, error) { + globalForwardingRules, err := e.repository.SearchAllGlobalForwardingRules() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(globalForwardingRules)) + + for _, res := range globalForwardingRules { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_health_check_enumerator.go b/enumeration/remote/google/google_compute_health_check_enumerator.go new file mode 100644 index 00000000..275e4e0b --- /dev/null +++ b/enumeration/remote/google/google_compute_health_check_enumerator.go @@ -0,0 +1,47 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeHealthCheckEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeHealthCheckEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeHealthCheckEnumerator { + return &GoogleComputeHealthCheckEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeHealthCheckEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeHealthCheckResourceType +} + +func (e *GoogleComputeHealthCheckEnumerator) Enumerate() ([]*resource.Resource, error) { + checks, err := e.repository.SearchAllHealthChecks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(checks)) + for _, res := range checks { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetDisplayName(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_image_enumerator.go b/enumeration/remote/google/google_compute_image_enumerator.go new file mode 100644 index 00000000..7dbf3de2 --- /dev/null +++ b/enumeration/remote/google/google_compute_image_enumerator.go @@ -0,0 +1,49 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeImageEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeImageEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeImageEnumerator { + return &GoogleComputeImageEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeImageEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeImageResourceType +} + +func (e *GoogleComputeImageEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllImages() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetDisplayName(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_instance_enumerator.go b/enumeration/remote/google/google_compute_instance_enumerator.go new file mode 100644 index 00000000..547c8a3f --- /dev/null +++ b/enumeration/remote/google/google_compute_instance_enumerator.go @@ -0,0 +1,47 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeInstanceEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeInstanceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeInstanceEnumerator { + return &GoogleComputeInstanceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeInstanceEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeInstanceResourceType +} + +func (e *GoogleComputeInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllInstances() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_instance_group_enumerator.go b/enumeration/remote/google/google_compute_instance_group_enumerator.go new file mode 100644 index 00000000..dfa987c3 --- /dev/null +++ b/enumeration/remote/google/google_compute_instance_group_enumerator.go @@ -0,0 +1,58 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "strings" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeInstanceGroupEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeInstanceGroupEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeInstanceGroupEnumerator { + return &GoogleComputeInstanceGroupEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeInstanceGroupEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeInstanceGroupResourceType +} + +func (e *GoogleComputeInstanceGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + groups, err := e.repository.SearchAllInstanceGroups() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(groups)) + for _, res := range groups { + splittedName := strings.Split(res.GetName(), "/") + if len(splittedName) != 9 { + logrus.WithField("name", res.GetName()).Error("Unable to decode project from instance group name") + continue + } + project := splittedName[4] + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetDisplayName(), + "project": project, + "location": res.GetLocation(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_instance_group_manager_enumerator.go b/enumeration/remote/google/google_compute_instance_group_manager_enumerator.go new file mode 100644 index 00000000..3257cc70 --- /dev/null +++ b/enumeration/remote/google/google_compute_instance_group_manager_enumerator.go @@ -0,0 +1,56 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "strings" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeInstanceGroupManagerEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeInstanceGroupManagerEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeInstanceGroupManagerEnumerator { + return &GoogleComputeInstanceGroupManagerEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeInstanceGroupManagerEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeInstanceGroupManagerResourceType +} + +func (e *GoogleComputeInstanceGroupManagerEnumerator) Enumerate() ([]*resource.Resource, error) { + items, err := e.repository.SearchAllInstanceGroupManagers() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(items)) + for _, res := range items { + splittedName := strings.Split(res.GetName(), "/") + if len(splittedName) != 9 { + logrus.WithField("name", res.GetName()).Error("Unable to decode project from instance group name") + continue + } + name := splittedName[8] + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": name, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_network_enumerator.go b/enumeration/remote/google/google_compute_network_enumerator.go new file mode 100644 index 00000000..a58a29ac --- /dev/null +++ b/enumeration/remote/google/google_compute_network_enumerator.go @@ -0,0 +1,48 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeNetworkEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeNetworkEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeNetworkEnumerator { + return &GoogleComputeNetworkEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeNetworkEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeNetworkResourceType +} + +func (e *GoogleComputeNetworkEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllNetworks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.DisplayName, + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_node_group_enumerator.go b/enumeration/remote/google/google_compute_node_group_enumerator.go new file mode 100644 index 00000000..fe2d724f --- /dev/null +++ b/enumeration/remote/google/google_compute_node_group_enumerator.go @@ -0,0 +1,47 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeNodeGroupEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeNodeGroupEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeNodeGroupEnumerator { + return &GoogleComputeNodeGroupEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeNodeGroupEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeNodeGroupResourceType +} + +func (e *GoogleComputeNodeGroupEnumerator) Enumerate() ([]*resource.Resource, error) { + nodeGroups, err := e.repository.SearchAllNodeGroups() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(nodeGroups)) + for _, res := range nodeGroups { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetName(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_router_enumerator.go b/enumeration/remote/google/google_compute_router_enumerator.go new file mode 100644 index 00000000..7c2d3f18 --- /dev/null +++ b/enumeration/remote/google/google_compute_router_enumerator.go @@ -0,0 +1,46 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeRouterEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeRouterEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeRouterEnumerator { + return &GoogleComputeRouterEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeRouterEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeRouterResourceType +} + +func (e *GoogleComputeRouterEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllRouters() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_compute_subnetwork_enumerator.go b/enumeration/remote/google/google_compute_subnetwork_enumerator.go new file mode 100644 index 00000000..19d4fabe --- /dev/null +++ b/enumeration/remote/google/google_compute_subnetwork_enumerator.go @@ -0,0 +1,49 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleComputeSubnetworkEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleComputeSubnetworkEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeSubnetworkEnumerator { + return &GoogleComputeSubnetworkEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleComputeSubnetworkEnumerator) SupportedType() resource.ResourceType { + return google.GoogleComputeSubnetworkResourceType +} + +func (e *GoogleComputeSubnetworkEnumerator) Enumerate() ([]*resource.Resource, error) { + subnets, err := e.repository.SearchAllSubnetworks() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(subnets)) + + for _, res := range subnets { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + trimResourceName(res.GetName()), + map[string]interface{}{ + "name": res.GetDisplayName(), + "region": res.GetLocation(), + }, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_dns_managed_zone_enumerator.go b/enumeration/remote/google/google_dns_managed_zone_enumerator.go new file mode 100644 index 00000000..ea872d48 --- /dev/null +++ b/enumeration/remote/google/google_dns_managed_zone_enumerator.go @@ -0,0 +1,59 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "strings" + + "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleDNSManagedZoneEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleDNSManagedZoneEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleDNSManagedZoneEnumerator { + return &GoogleDNSManagedZoneEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleDNSManagedZoneEnumerator) SupportedType() resource.ResourceType { + return google.GoogleDNSManagedZoneResourceType +} + +func (e *GoogleDNSManagedZoneEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllDNSManagedZones() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + // We should have ID = "projects/cloudskiff-dev-elie/managedZones/example-zone" + // We have projects/cloudskiff-dev-elie/managedZones/2435093289230056557 + for _, res := range resources { + id := trimResourceName(res.Name) + splittedId := strings.Split(id, "/managedZones/") + if len(splittedId) != 2 { + logrus.WithField("id", res.Name).Warn("Cannot parse google_dns_managed_zone ID") + continue + } + id = strings.Join([]string{splittedId[0], "managedZones", res.DisplayName}, "/") + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_project_iam_member_enumerator.go b/enumeration/remote/google/google_project_iam_member_enumerator.go new file mode 100644 index 00000000..5af56429 --- /dev/null +++ b/enumeration/remote/google/google_project_iam_member_enumerator.go @@ -0,0 +1,57 @@ +package google + +import ( + "fmt" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleProjectIamMemberEnumerator struct { + repository repository.CloudResourceManagerRepository + factory resource.ResourceFactory +} + +func NewGoogleProjectIamMemberEnumerator(repo repository.CloudResourceManagerRepository, factory resource.ResourceFactory) *GoogleProjectIamMemberEnumerator { + return &GoogleProjectIamMemberEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleProjectIamMemberEnumerator) SupportedType() resource.ResourceType { + return google.GoogleProjectIamMemberResourceType +} + +func (e *GoogleProjectIamMemberEnumerator) Enumerate() ([]*resource.Resource, error) { + results := make([]*resource.Resource, 0) + + bindingsByProject, err := e.repository.ListProjectsBindings() + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for project, bindings := range bindingsByProject { + for roleName, members := range bindings { + for _, member := range members { + id := fmt.Sprintf("%s/%s/%s", project, roleName, member) + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{ + "id": id, + "project": project, + "role": roleName, + "member": member, + }, + ), + ) + } + } + } + + return results, err +} diff --git a/enumeration/remote/google/google_sql_database_instance_enumerator.go b/enumeration/remote/google/google_sql_database_instance_enumerator.go new file mode 100644 index 00000000..fe0eb5ca --- /dev/null +++ b/enumeration/remote/google/google_sql_database_instance_enumerator.go @@ -0,0 +1,53 @@ +package google + +import ( + "github.com/sirupsen/logrus" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleSQLDatabaseInstanceEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleSQLDatabaseInstanceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleSQLDatabaseInstanceEnumerator { + return &GoogleSQLDatabaseInstanceEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleSQLDatabaseInstanceEnumerator) SupportedType() resource.ResourceType { + return google.GoogleSQLDatabaseInstanceResourceType +} + +func (e *GoogleSQLDatabaseInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllSQLDatabaseInstances() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + name, exist := res.GetResource().GetData().GetFields()["name"] + if !exist || name.GetStringValue() == "" { + logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") + continue + } + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + name.GetStringValue(), + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_storage_bucket_enumerator.go b/enumeration/remote/google/google_storage_bucket_enumerator.go new file mode 100644 index 00000000..03237cb3 --- /dev/null +++ b/enumeration/remote/google/google_storage_bucket_enumerator.go @@ -0,0 +1,47 @@ +package google + +import ( + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleStorageBucketEnumerator struct { + repository repository.AssetRepository + factory resource.ResourceFactory +} + +func NewGoogleStorageBucketEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleStorageBucketEnumerator { + return &GoogleStorageBucketEnumerator{ + repository: repo, + factory: factory, + } +} + +func (e *GoogleStorageBucketEnumerator) SupportedType() resource.ResourceType { + return google.GoogleStorageBucketResourceType +} + +func (e *GoogleStorageBucketEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllBuckets() + + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, res := range resources { + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + res.DisplayName, + map[string]interface{}{}, + ), + ) + } + + return results, err +} diff --git a/enumeration/remote/google/google_storage_bucket_iam_member_enumerator.go b/enumeration/remote/google/google_storage_bucket_iam_member_enumerator.go new file mode 100644 index 00000000..1006bb43 --- /dev/null +++ b/enumeration/remote/google/google_storage_bucket_iam_member_enumerator.go @@ -0,0 +1,64 @@ +package google + +import ( + "fmt" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + repository2 "github.com/snyk/driftctl/enumeration/remote/google/repository" + + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) + +type GoogleStorageBucketIamMemberEnumerator struct { + repository repository2.AssetRepository + storageRepository repository2.StorageRepository + factory resource.ResourceFactory +} + +func NewGoogleStorageBucketIamMemberEnumerator(repo repository2.AssetRepository, storageRepo repository2.StorageRepository, factory resource.ResourceFactory) *GoogleStorageBucketIamMemberEnumerator { + return &GoogleStorageBucketIamMemberEnumerator{ + repository: repo, + storageRepository: storageRepo, + factory: factory, + } +} + +func (e *GoogleStorageBucketIamMemberEnumerator) SupportedType() resource.ResourceType { + return google.GoogleStorageBucketIamMemberResourceType +} + +func (e *GoogleStorageBucketIamMemberEnumerator) Enumerate() ([]*resource.Resource, error) { + resources, err := e.repository.SearchAllBuckets() + if err != nil { + return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), google.GoogleStorageBucketResourceType) + } + + results := make([]*resource.Resource, 0, len(resources)) + + for _, bucket := range resources { + bindings, err := e.storageRepository.ListAllBindings(bucket.DisplayName) + if err != nil { + return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) + } + for roleName, members := range bindings { + for _, member := range members { + id := fmt.Sprintf("b/%s/%s/%s", bucket.DisplayName, roleName, member) + results = append( + results, + e.factory.CreateAbstractResource( + string(e.SupportedType()), + id, + map[string]interface{}{ + "id": id, + "bucket": fmt.Sprintf("b/%s", bucket.DisplayName), + "role": roleName, + "member": member, + }, + ), + ) + } + } + } + + return results, err +} diff --git a/enumeration/remote/google/init.go b/enumeration/remote/google/init.go new file mode 100644 index 00000000..be795acd --- /dev/null +++ b/enumeration/remote/google/init.go @@ -0,0 +1,121 @@ +package google + +import ( + "context" + + "github.com/snyk/driftctl/enumeration" + + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + repository2 "github.com/snyk/driftctl/enumeration/remote/google/repository" + "github.com/snyk/driftctl/enumeration/terraform" + + asset "cloud.google.com/go/asset/apiv1" + "cloud.google.com/go/storage" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" + "google.golang.org/api/cloudresourcemanager/v1" +) + +func Init(version string, alerter *alerter.Alerter, + providerLibrary *terraform.ProviderLibrary, + remoteLibrary *common2.RemoteLibrary, + progress enumeration.ProgressCounter, + resourceSchemaRepository *resource.SchemaRepository, + factory resource.ResourceFactory, + configDir string) error { + + provider, err := NewGCPTerraformProvider(version, progress, configDir) + if err != nil { + return err + } + + err = provider.CheckCredentialsExist() + if err != nil { + return err + } + + err = provider.Init() + if err != nil { + return err + } + + repositoryCache := cache.New(100) + + ctx := context.Background() + assetClient, err := asset.NewClient(ctx) + if err != nil { + return err + } + + storageClient, err := storage.NewClient(ctx) + if err != nil { + return err + } + + crmService, err := cloudresourcemanager.NewService(ctx) + if err != nil { + return err + } + + assetRepository := repository2.NewAssetRepository(assetClient, provider.GetConfig(), repositoryCache) + storageRepository := repository2.NewStorageRepository(storageClient, repositoryCache) + iamRepository := repository2.NewCloudResourceManagerRepository(crmService, provider.GetConfig(), repositoryCache) + + providerLibrary.AddProvider(terraform.GOOGLE, provider) + deserializer := resource.NewDeserializer(factory) + + remoteLibrary.AddEnumerator(NewGoogleStorageBucketEnumerator(assetRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleStorageBucketResourceType, common2.NewGenericDetailsFetcher(google.GoogleStorageBucketResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleComputeFirewallEnumerator(assetRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleComputeFirewallResourceType, common2.NewGenericDetailsFetcher(google.GoogleComputeFirewallResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleComputeRouterEnumerator(assetRepository, factory)) + + remoteLibrary.AddEnumerator(NewGoogleComputeInstanceEnumerator(assetRepository, factory)) + + remoteLibrary.AddEnumerator(NewGoogleProjectIamMemberEnumerator(iamRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleProjectIamMemberResourceType, common2.NewGenericDetailsFetcher(google.GoogleProjectIamMemberResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleStorageBucketIamMemberEnumerator(assetRepository, storageRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleStorageBucketIamMemberResourceType, common2.NewGenericDetailsFetcher(google.GoogleStorageBucketIamMemberResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleComputeNetworkEnumerator(assetRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleComputeNetworkResourceType, common2.NewGenericDetailsFetcher(google.GoogleComputeNetworkResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleComputeSubnetworkEnumerator(assetRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleComputeSubnetworkResourceType, common2.NewGenericDetailsFetcher(google.GoogleComputeSubnetworkResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleDNSManagedZoneEnumerator(assetRepository, factory)) + + remoteLibrary.AddEnumerator(NewGoogleComputeInstanceGroupEnumerator(assetRepository, factory)) + remoteLibrary.AddDetailsFetcher(google.GoogleComputeInstanceGroupResourceType, common2.NewGenericDetailsFetcher(google.GoogleComputeInstanceGroupResourceType, provider, deserializer)) + + remoteLibrary.AddEnumerator(NewGoogleBigqueryDatasetEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleBigqueryTableEnumerator(assetRepository, factory)) + + remoteLibrary.AddEnumerator(NewGoogleComputeAddressEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeGlobalAddressEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleCloudFunctionsFunctionEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeDiskEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeImageEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleBigTableInstanceEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleBigtableTableEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleSQLDatabaseInstanceEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeHealthCheckEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleCloudRunServiceEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeNodeGroupEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeForwardingRuleEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeInstanceGroupManagerEnumerator(assetRepository, factory)) + remoteLibrary.AddEnumerator(NewGoogleComputeGlobalForwardingRuleEnumerator(assetRepository, factory)) + + err = resourceSchemaRepository.Init(terraform.GOOGLE, provider.Version(), provider.Schema()) + if err != nil { + return err + } + google.InitResourcesMetadata(resourceSchemaRepository) + + return nil +} diff --git a/enumeration/remote/google/provider.go b/enumeration/remote/google/provider.go new file mode 100644 index 00000000..276634b2 --- /dev/null +++ b/enumeration/remote/google/provider.go @@ -0,0 +1,78 @@ +package google + +import ( + "context" + "errors" + "os" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/google/config" + "github.com/snyk/driftctl/enumeration/remote/terraform" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + + asset "cloud.google.com/go/asset/apiv1" +) + +type GCPTerraformProvider struct { + *terraform.TerraformProvider + name string + version string +} + +func NewGCPTerraformProvider(version string, progress enumeration.ProgressCounter, configDir string) (*GCPTerraformProvider, error) { + if version == "" { + version = "3.78.0" + } + p := &GCPTerraformProvider{ + version: version, + name: terraform2.GOOGLE, + } + installer, err := terraform2.NewProviderInstaller(terraform2.ProviderConfig{ + Key: p.name, + Version: version, + ConfigDir: configDir, + }) + if err != nil { + return nil, err + } + tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ + Name: p.name, + GetProviderConfig: func(alias string) interface{} { + return p.GetConfig() + }, + }, progress) + + if err != nil { + return nil, err + } + + p.TerraformProvider = tfProvider + + return p, err +} + +func (p *GCPTerraformProvider) Name() string { + return p.name +} + +func (p *GCPTerraformProvider) Version() string { + return p.version +} + +func (p *GCPTerraformProvider) GetConfig() config.GCPTerraformConfig { + return config.GCPTerraformConfig{ + Project: os.Getenv("CLOUDSDK_CORE_PROJECT"), + Region: os.Getenv("CLOUDSDK_COMPUTE_REGION"), + Zone: os.Getenv("CLOUDSDK_COMPUTE_ZONE"), + } +} + +func (p *GCPTerraformProvider) CheckCredentialsExist() error { + client, err := asset.NewClient(context.Background()) + if err != nil { + return errors.New("Please use a Service Account to authenticate on GCP.\n" + + "For more information: https://cloud.google.com/docs/authentication/production") + } + _ = client.Close() + return nil +} diff --git a/enumeration/remote/google/repository/asset.go b/enumeration/remote/google/repository/asset.go new file mode 100644 index 00000000..153d0d46 --- /dev/null +++ b/enumeration/remote/google/repository/asset.go @@ -0,0 +1,282 @@ +package repository + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + "github.com/snyk/driftctl/enumeration/remote/google/config" + + asset "cloud.google.com/go/asset/apiv1" + "google.golang.org/api/iterator" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" +) + +// https://cloud.google.com/asset-inventory/docs/supported-asset-types#supported_resource_types +const ( + storageBucketAssetType = "storage.googleapis.com/Bucket" + computeFirewallAssetType = "compute.googleapis.com/Firewall" + computeRouterAssetType = "compute.googleapis.com/Router" + computeInstanceAssetType = "compute.googleapis.com/Instance" + computeNetworkAssetType = "compute.googleapis.com/Network" + computeSubnetworkAssetType = "compute.googleapis.com/Subnetwork" + computeDiskAssetType = "compute.googleapis.com/Disk" + computeImageAssetType = "compute.googleapis.com/Image" + dnsManagedZoneAssetType = "dns.googleapis.com/ManagedZone" + computeInstanceGroupAssetType = "compute.googleapis.com/InstanceGroup" + bigqueryDatasetAssetType = "bigquery.googleapis.com/Dataset" + bigqueryTableAssetType = "bigquery.googleapis.com/Table" + computeAddressAssetType = "compute.googleapis.com/Address" + computeGlobalAddressAssetType = "compute.googleapis.com/GlobalAddress" + cloudFunctionsFunction = "cloudfunctions.googleapis.com/CloudFunction" + bigtableInstanceAssetType = "bigtableadmin.googleapis.com/Instance" + bigtableTableAssetType = "bigtableadmin.googleapis.com/Table" + sqlDatabaseInstanceAssetType = "sqladmin.googleapis.com/Instance" + healthCheckAssetType = "compute.googleapis.com/HealthCheck" + cloudRunServiceAssetType = "run.googleapis.com/Service" + nodeGroupAssetType = "compute.googleapis.com/NodeGroup" + computeForwardingRuleAssetType = "compute.googleapis.com/ForwardingRule" + instanceGroupManagerAssetType = "compute.googleapis.com/InstanceGroupManager" + computeGlobalForwardingRuleAssetType = "compute.googleapis.com/GlobalForwardingRule" +) + +type AssetRepository interface { + SearchAllBuckets() ([]*assetpb.ResourceSearchResult, error) + SearchAllFirewalls() ([]*assetpb.ResourceSearchResult, error) + SearchAllRouters() ([]*assetpb.ResourceSearchResult, error) + SearchAllInstances() ([]*assetpb.ResourceSearchResult, error) + SearchAllNetworks() ([]*assetpb.ResourceSearchResult, error) + SearchAllDisks() ([]*assetpb.ResourceSearchResult, error) + SearchAllImages() ([]*assetpb.ResourceSearchResult, error) + SearchAllDNSManagedZones() ([]*assetpb.ResourceSearchResult, error) + SearchAllInstanceGroups() ([]*assetpb.ResourceSearchResult, error) + SearchAllDatasets() ([]*assetpb.ResourceSearchResult, error) + SearchAllTables() ([]*assetpb.ResourceSearchResult, error) + SearchAllAddresses() ([]*assetpb.ResourceSearchResult, error) + SearchAllGlobalAddresses() ([]*assetpb.Asset, error) + SearchAllFunctions() ([]*assetpb.Asset, error) + SearchAllSubnetworks() ([]*assetpb.ResourceSearchResult, error) + SearchAllBigtableInstances() ([]*assetpb.Asset, error) + SearchAllBigtableTables() ([]*assetpb.Asset, error) + SearchAllSQLDatabaseInstances() ([]*assetpb.Asset, error) + SearchAllHealthChecks() ([]*assetpb.ResourceSearchResult, error) + SearchAllCloudRunServices() ([]*assetpb.ResourceSearchResult, error) + SearchAllNodeGroups() ([]*assetpb.Asset, error) + SearchAllForwardingRules() ([]*assetpb.Asset, error) + SearchAllInstanceGroupManagers() ([]*assetpb.Asset, error) + SearchAllGlobalForwardingRules() ([]*assetpb.Asset, error) +} + +type assetRepository struct { + client *asset.Client + config config.GCPTerraformConfig + cache cache.Cache +} + +func NewAssetRepository(client *asset.Client, config config.GCPTerraformConfig, c cache.Cache) *assetRepository { + return &assetRepository{ + client, + config, + c, + } +} + +func (s assetRepository) listAllResources(ty string) ([]*assetpb.Asset, error) { + req := &assetpb.ListAssetsRequest{ + Parent: fmt.Sprintf("projects/%s", s.config.Project), + ContentType: assetpb.ContentType_RESOURCE, + AssetTypes: []string{ + cloudFunctionsFunction, + bigtableInstanceAssetType, + bigtableTableAssetType, + sqlDatabaseInstanceAssetType, + computeGlobalAddressAssetType, + nodeGroupAssetType, + computeForwardingRuleAssetType, + instanceGroupManagerAssetType, + computeGlobalForwardingRuleAssetType, + }, + } + var results []*assetpb.Asset + + cacheKey := "listAllResources" + cachedResults := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if cachedResults != nil { + results = cachedResults.([]*assetpb.Asset) + } + + if results == nil { + it := s.client.ListAssets(context.Background(), req) + for { + resource, err := it.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, err + } + results = append(results, resource) + } + s.cache.Put(cacheKey, results) + } + + filteredResults := []*assetpb.Asset{} + for _, result := range results { + if result.AssetType == ty { + filteredResults = append(filteredResults, result) + } + } + + return filteredResults, nil +} + +func (s assetRepository) searchAllResources(ty string) ([]*assetpb.ResourceSearchResult, error) { + req := &assetpb.SearchAllResourcesRequest{ + Scope: fmt.Sprintf("projects/%s", s.config.Project), + AssetTypes: []string{ + storageBucketAssetType, + computeFirewallAssetType, + computeRouterAssetType, + computeInstanceAssetType, + computeNetworkAssetType, + computeSubnetworkAssetType, + dnsManagedZoneAssetType, + computeInstanceGroupAssetType, + bigqueryDatasetAssetType, + bigqueryTableAssetType, + computeAddressAssetType, + computeDiskAssetType, + computeImageAssetType, + healthCheckAssetType, + cloudRunServiceAssetType, + }, + } + var results []*assetpb.ResourceSearchResult + + cacheKey := "SearchAllResources" + cachedResults := s.cache.GetAndLock(cacheKey) + defer s.cache.Unlock(cacheKey) + if cachedResults != nil { + results = cachedResults.([]*assetpb.ResourceSearchResult) + } + + if results == nil { + it := s.client.SearchAllResources(context.Background(), req) + for { + resource, err := it.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, err + } + results = append(results, resource) + } + s.cache.Put(cacheKey, results) + } + + filteredResults := []*assetpb.ResourceSearchResult{} + for _, result := range results { + if result.AssetType == ty { + filteredResults = append(filteredResults, result) + } + } + + return filteredResults, nil +} + +func (s assetRepository) SearchAllBuckets() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(storageBucketAssetType) +} + +func (s assetRepository) SearchAllFirewalls() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeFirewallAssetType) +} + +func (s assetRepository) SearchAllRouters() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeRouterAssetType) +} + +func (s assetRepository) SearchAllInstances() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeInstanceAssetType) +} + +func (s assetRepository) SearchAllNetworks() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeNetworkAssetType) +} + +func (s assetRepository) SearchAllDNSManagedZones() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(dnsManagedZoneAssetType) +} + +func (s assetRepository) SearchAllInstanceGroups() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeInstanceGroupAssetType) +} + +func (s assetRepository) SearchAllDatasets() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(bigqueryDatasetAssetType) +} + +func (s assetRepository) SearchAllTables() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(bigqueryTableAssetType) +} + +func (s assetRepository) SearchAllAddresses() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeAddressAssetType) +} + +func (s assetRepository) SearchAllGlobalAddresses() ([]*assetpb.Asset, error) { + return s.listAllResources(computeGlobalAddressAssetType) +} + +func (s assetRepository) SearchAllFunctions() ([]*assetpb.Asset, error) { + return s.listAllResources(cloudFunctionsFunction) +} + +func (s assetRepository) SearchAllSubnetworks() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeSubnetworkAssetType) +} + +func (s assetRepository) SearchAllDisks() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeDiskAssetType) +} + +func (s assetRepository) SearchAllImages() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(computeImageAssetType) +} + +func (s assetRepository) SearchAllBigtableInstances() ([]*assetpb.Asset, error) { + return s.listAllResources(bigtableInstanceAssetType) +} + +func (s assetRepository) SearchAllBigtableTables() ([]*assetpb.Asset, error) { + return s.listAllResources(bigtableTableAssetType) +} + +func (s assetRepository) SearchAllSQLDatabaseInstances() ([]*assetpb.Asset, error) { + return s.listAllResources(sqlDatabaseInstanceAssetType) +} + +func (s assetRepository) SearchAllHealthChecks() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(healthCheckAssetType) +} + +func (s assetRepository) SearchAllCloudRunServices() ([]*assetpb.ResourceSearchResult, error) { + return s.searchAllResources(cloudRunServiceAssetType) +} + +func (s assetRepository) SearchAllNodeGroups() ([]*assetpb.Asset, error) { + return s.listAllResources(nodeGroupAssetType) +} + +func (s assetRepository) SearchAllForwardingRules() ([]*assetpb.Asset, error) { + return s.listAllResources(computeForwardingRuleAssetType) +} + +func (s assetRepository) SearchAllInstanceGroupManagers() ([]*assetpb.Asset, error) { + return s.listAllResources(instanceGroupManagerAssetType) +} + +func (s assetRepository) SearchAllGlobalForwardingRules() ([]*assetpb.Asset, error) { + return s.listAllResources(computeGlobalForwardingRuleAssetType) +} diff --git a/enumeration/remote/google/repository/asset_test.go b/enumeration/remote/google/repository/asset_test.go new file mode 100644 index 00000000..b831a177 --- /dev/null +++ b/enumeration/remote/google/repository/asset_test.go @@ -0,0 +1,64 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "github.com/snyk/driftctl/enumeration/remote/google/config" + "testing" + + "github.com/snyk/driftctl/test/google" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" +) + +func Test_assetRepository_searchAllResources_CacheHit(t *testing.T) { + + expectedResults := []*assetpb.ResourceSearchResult{ + { + AssetType: "google_fake_type", + DisplayName: "driftctl-unittest-1", + }, + { + AssetType: "google_another_fake_type", + DisplayName: "driftctl-unittest-1", + }, + } + + c := &cache.MockCache{} + c.On("GetAndLock", "SearchAllResources").Return(expectedResults).Times(1) + c.On("Unlock", "SearchAllResources").Times(1) + repo := NewAssetRepository(nil, config.GCPTerraformConfig{Project: ""}, c) + + got, err := repo.searchAllResources("google_fake_type") + c.AssertExpectations(t) + assert.Nil(t, err) + assert.Len(t, got, 1) +} + +func Test_assetRepository_searchAllResources_CacheMiss(t *testing.T) { + + expectedResults := []*assetpb.ResourceSearchResult{ + { + AssetType: "google_fake_type", + DisplayName: "driftctl-unittest-1", + }, + { + AssetType: "google_another_fake_type", + DisplayName: "driftctl-unittest-1", + }, + } + assetClient, err := google.NewFakeAssetServer(expectedResults, nil) + if err != nil { + t.Fatal(err) + } + c := &cache.MockCache{} + c.On("GetAndLock", "SearchAllResources").Return(nil).Times(1) + c.On("Unlock", "SearchAllResources").Times(1) + c.On("Put", "SearchAllResources", mock.IsType([]*assetpb.ResourceSearchResult{})).Return(false).Times(1) + repo := NewAssetRepository(assetClient, config.GCPTerraformConfig{Project: ""}, c) + + got, err := repo.searchAllResources("google_fake_type") + c.AssertExpectations(t) + assert.Nil(t, err) + assert.Len(t, got, 1) +} diff --git a/enumeration/remote/google/repository/cloudresourcemanager.go b/enumeration/remote/google/repository/cloudresourcemanager.go new file mode 100644 index 00000000..09507a04 --- /dev/null +++ b/enumeration/remote/google/repository/cloudresourcemanager.go @@ -0,0 +1,50 @@ +package repository + +import ( + "github.com/snyk/driftctl/enumeration/remote/cache" + "github.com/snyk/driftctl/enumeration/remote/google/config" + "google.golang.org/api/cloudresourcemanager/v1" +) + +type CloudResourceManagerRepository interface { + ListProjectsBindings() (map[string]map[string][]string, error) +} + +type cloudResourceManagerRepository struct { + service *cloudresourcemanager.Service + config config.GCPTerraformConfig + cache cache.Cache +} + +func NewCloudResourceManagerRepository(service *cloudresourcemanager.Service, config config.GCPTerraformConfig, cache cache.Cache) CloudResourceManagerRepository { + return &cloudResourceManagerRepository{ + service: service, + config: config, + cache: cache, + } +} + +func (s *cloudResourceManagerRepository) ListProjectsBindings() (map[string]map[string][]string, error) { + if cachedResults := s.cache.Get("ListProjectsBindings"); cachedResults != nil { + return cachedResults.(map[string]map[string][]string), nil + } + + request := new(cloudresourcemanager.GetIamPolicyRequest) + policy, err := s.service.Projects.GetIamPolicy(s.config.Project, request).Do() + if err != nil { + return nil, err + } + + bindings := make(map[string][]string) + + for _, binding := range policy.Bindings { + bindings[binding.Role] = binding.Members + } + + bindingsByProject := make(map[string]map[string][]string) + bindingsByProject[s.config.Project] = bindings + + s.cache.Put("ListProjectsBindings", bindingsByProject) + + return bindingsByProject, nil +} diff --git a/pkg/remote/google/repository/mock_AssetRepository.go b/enumeration/remote/google/repository/mock_AssetRepository.go similarity index 100% rename from pkg/remote/google/repository/mock_AssetRepository.go rename to enumeration/remote/google/repository/mock_AssetRepository.go diff --git a/pkg/remote/google/repository/mock_CloudResourceManagerRepository.go b/enumeration/remote/google/repository/mock_CloudResourceManagerRepository.go similarity index 100% rename from pkg/remote/google/repository/mock_CloudResourceManagerRepository.go rename to enumeration/remote/google/repository/mock_CloudResourceManagerRepository.go diff --git a/pkg/remote/google/repository/mock_StorageRepository.go b/enumeration/remote/google/repository/mock_StorageRepository.go similarity index 100% rename from pkg/remote/google/repository/mock_StorageRepository.go rename to enumeration/remote/google/repository/mock_StorageRepository.go diff --git a/enumeration/remote/google/repository/storage.go b/enumeration/remote/google/repository/storage.go new file mode 100644 index 00000000..ae33bc10 --- /dev/null +++ b/enumeration/remote/google/repository/storage.go @@ -0,0 +1,52 @@ +package repository + +import ( + "context" + "fmt" + "github.com/snyk/driftctl/enumeration/remote/cache" + "sync" + + "cloud.google.com/go/storage" +) + +type StorageRepository interface { + ListAllBindings(bucketName string) (map[string][]string, error) +} + +type storageRepository struct { + client *storage.Client + cache cache.Cache + lock sync.Locker +} + +func NewStorageRepository(client *storage.Client, cache cache.Cache) *storageRepository { + return &storageRepository{ + client: client, + cache: cache, + lock: &sync.Mutex{}, + } +} + +func (s storageRepository) ListAllBindings(bucketName string) (map[string][]string, error) { + + s.lock.Lock() + defer s.lock.Unlock() + if cachedResults := s.cache.Get(fmt.Sprintf("%s-%s", "ListAllBindings", bucketName)); cachedResults != nil { + return cachedResults.(map[string][]string), nil + } + + bucket := s.client.Bucket(bucketName) + policy, err := bucket.IAM().Policy(context.Background()) + if err != nil { + return nil, err + } + bindings := make(map[string][]string) + for _, name := range policy.Roles() { + members := policy.Members(name) + bindings[string(name)] = members + } + + s.cache.Put("ListAllBindings", bindings) + + return bindings, nil +} diff --git a/pkg/remote/google/util.go b/enumeration/remote/google/util.go similarity index 100% rename from pkg/remote/google/util.go rename to enumeration/remote/google/util.go diff --git a/enumeration/remote/google_bigquery_scanner_test.go b/enumeration/remote/google_bigquery_scanner_test.go new file mode 100644 index 00000000..6f8b754e --- /dev/null +++ b/enumeration/remote/google_bigquery_scanner_test.go @@ -0,0 +1,233 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + google2 "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestGoogleBigqueryDataset(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no dataset", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples dataset", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/datasets/example_dataset", got[0].ResourceId()) + assert.Equal(t, "google_bigquery_dataset", got[0].ResourceType()) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "bigquery.googleapis.com/Dataset", + Name: "//bigquery.googleapis.com/projects/cloudskiff-dev-elie/datasets/example_dataset", + }, + }, + }, + { + test: "cannot list datasets", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_bigquery_dataset", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_bigquery_dataset", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleBigqueryDatasetEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleBigqueryTable(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no table", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples table", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/datasets/example_dataset/tables/bar", got[0].ResourceId()) + assert.Equal(t, "google_bigquery_table", got[0].ResourceType()) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "bigquery.googleapis.com/Table", + Name: "//bigquery.googleapis.com/projects/cloudskiff-dev-elie/datasets/example_dataset/tables/bar", + }, + }, + }, + { + test: "cannot list table", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_bigquery_table", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_bigquery_table", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleBigqueryTableEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_bigtable_scanner_test.go b/enumeration/remote/google_bigtable_scanner_test.go new file mode 100644 index 00000000..3bec37ac --- /dev/null +++ b/enumeration/remote/google_bigtable_scanner_test.go @@ -0,0 +1,293 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + google2 "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestGoogleBigtableInstance(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no instance", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "one instance returned", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/instances/tf-instance", got[0].ResourceId()) + assert.Equal(t, "google_bigtable_instance", got[0].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "bigtableadmin.googleapis.com/Instance", + Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance", + Resource: &assetpb.Resource{ + Data: func() *structpb.Struct { + v, err := structpb.NewStruct(map[string]interface{}{ + "name": "projects/cloudskiff-dev-elie/instances/tf-instance", + }) + if err != nil { + t.Fatal(err) + } + return v + }(), + }, + }, + }, + }, + { + test: "one instance without resource data", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + response: []*assetpb.Asset{ + { + AssetType: "bigtableadmin.googleapis.com/Instance", + Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance", + }, + { + AssetType: "bigtableadmin.googleapis.com/Instance", + Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance-2", + Resource: &assetpb.Resource{ + Data: func() *structpb.Struct { + v, err := structpb.NewStruct(map[string]interface{}{}) + if err != nil { + t.Fatal(err) + } + return v + }(), + }, + }, + }, + }, + { + test: "cannot list instances", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_bigtable_instance", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_bigtable_instance", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleBigTableInstanceEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleBigtableTable(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no table", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "one resource returned", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", got[0].ResourceId()) + assert.Equal(t, "google_bigtable_table", got[0].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "bigtableadmin.googleapis.com/Table", + Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", + Resource: &assetpb.Resource{ + Data: func() *structpb.Struct { + v, err := structpb.NewStruct(map[string]interface{}{ + "name": "projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", + }) + if err != nil { + t.Fatal(err) + } + return v + }(), + }, + }, + }, + }, + { + test: "one resource without resource data", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + response: []*assetpb.Asset{ + { + AssetType: "bigtableadmin.googleapis.com/Table", + Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", + }, + }, + }, + { + test: "cannot list cloud functions", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_bigtable_table", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_bigtable_table", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleBigtableTableEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_cloudfunctions_scanner_test.go b/enumeration/remote/google_cloudfunctions_scanner_test.go new file mode 100644 index 00000000..a0b63170 --- /dev/null +++ b/enumeration/remote/google_cloudfunctions_scanner_test.go @@ -0,0 +1,154 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestGoogleCloudFunctionsFunction(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute instance", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "one cloud function returned", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", got[0].ResourceId()) + assert.Equal(t, "google_cloudfunctions_function", got[0].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "cloudfunctions.googleapis.com/CloudFunction", + Name: "//cloudfunctions.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", + Resource: &assetpb.Resource{ + Data: func() *structpb.Struct { + v, err := structpb.NewStruct(map[string]interface{}{ + "name": "projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", + }) + if err != nil { + t.Fatal(err) + } + return v + }(), + }, + }, + }, + }, + { + test: "one cloud function without resource data", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + response: []*assetpb.Asset{ + { + AssetType: "cloudfunctions.googleapis.com/CloudFunction", + Name: "//cloudfunctions.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", + }, + }, + }, + { + test: "cannot list cloud functions", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_cloudfunctions_function", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_cloudfunctions_function", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google.NewGoogleCloudFunctionsFunctionEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_cloudrun_scanner_test.go b/enumeration/remote/google_cloudrun_scanner_test.go new file mode 100644 index 00000000..51c7a0f5 --- /dev/null +++ b/enumeration/remote/google_cloudrun_scanner_test.go @@ -0,0 +1,150 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestGoogleCloudRunService(t *testing.T) { + + cases := []struct { + test string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + assertExpected func(t *testing.T, got []*resource.Resource) + }{ + { + test: "no resource", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples resources", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "run.googleapis.com/Service", + Name: "invalid ID", // Should be ignored + }, + { + AssetType: "run.googleapis.com/Service", + DisplayName: "cloudrun-srv-1", + Name: "//run.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/services/cloudrun-srv-1", + Location: "us-central1", + }, + { + AssetType: "run.googleapis.com/Service", + DisplayName: "cloudrun-srv-2", + Name: "//run.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/services/cloudrun-srv-2", + Location: "us-central1", + }, + }, + wantErr: nil, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + + assert.Equal(t, got[0].ResourceId(), "locations/us-central1/namespaces/cloudskiff-dev-elie/services/cloudrun-srv-1") + assert.Equal(t, got[0].ResourceType(), googleresource.GoogleCloudRunServiceResourceType) + + assert.Equal(t, got[1].ResourceId(), "locations/us-central1/namespaces/cloudskiff-dev-elie/services/cloudrun-srv-2") + assert.Equal(t, got[1].ResourceType(), googleresource.GoogleCloudRunServiceResourceType) + }, + }, + { + test: "should return access denied error", + wantErr: nil, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + googleresource.GoogleCloudRunServiceResourceType, + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + googleresource.GoogleCloudRunServiceResourceType, + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google.NewGoogleCloudRunServiceEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_compute_scanner_test.go b/enumeration/remote/google_compute_scanner_test.go new file mode 100644 index 00000000..27651c6b --- /dev/null +++ b/enumeration/remote/google_compute_scanner_test.go @@ -0,0 +1,1765 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + google2 "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestGoogleComputeFirewall(t *testing.T) { + + cases := []struct { + test string + dirName string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute firewall", + dirName: "google_compute_firewall_empty", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + }, + { + test: "multiples compute firewall", + dirName: "google_compute_firewall", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Firewall", + DisplayName: "test-firewall-0", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-0", + }, + { + AssetType: "compute.googleapis.com/Firewall", + DisplayName: "test-firewall-1", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-1", + }, + { + AssetType: "compute.googleapis.com/Firewall", + DisplayName: "test-firewall-2", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-2", + }, + }, + wantErr: nil, + }, + { + test: "cannot list compute firewall", + dirName: "google_compute_firewall_empty", + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_firewall", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_firewall", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleComputeFirewallResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err = realProvider.Init() + if err != nil { + tt.Fatal(err) + } + provider.ShouldUpdate() + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeFirewallEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resType, common2.NewGenericDetailsFetcher(resType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} + +func TestGoogleComputeRouter(t *testing.T) { + + cases := []struct { + test string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + assertExpected func(t *testing.T, got []*resource.Resource) + }{ + { + test: "no compute router", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute routers", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Router", + DisplayName: "test-router-0", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-0", + }, + { + AssetType: "compute.googleapis.com/Router", + DisplayName: "test-router-1", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-1", + }, + { + AssetType: "compute.googleapis.com/Router", + DisplayName: "test-router-2", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-2", + }, + }, + wantErr: nil, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 3) + + assert.Equal(t, got[0].ResourceId(), "projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-0") + assert.Equal(t, got[0].ResourceType(), googleresource.GoogleComputeRouterResourceType) + + assert.Equal(t, got[1].ResourceId(), "projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-1") + assert.Equal(t, got[1].ResourceType(), googleresource.GoogleComputeRouterResourceType) + + assert.Equal(t, got[2].ResourceId(), "projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-2") + assert.Equal(t, got[2].ResourceType(), googleresource.GoogleComputeRouterResourceType) + }, + }, + { + test: "should return access denied error", + wantErr: nil, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + googleresource.GoogleComputeRouterResourceType, + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + googleresource.GoogleComputeRouterResourceType, + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeRouterEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeInstance(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute instance", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute instances", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/zones/us-central1-a/instances/test", got[0].ResourceId()) + assert.Equal(t, "google_compute_instance", got[0].ResourceType()) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Instance", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/zones/us-central1-a/instances/test", + }, + }, + }, + { + test: "cannot list compute firewall", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_instance", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_instance", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeInstanceEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeNetwork(t *testing.T) { + + cases := []struct { + test string + dirName string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no network", + dirName: "google_compute_network_empty", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + }, + { + test: "multiple networks", + dirName: "google_compute_network", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Network", + DisplayName: "driftctl-unittest-1", + Name: "//compute.googleapis.com/projects/driftctl-qa-1/global/networks/driftctl-unittest-1", + }, + { + AssetType: "compute.googleapis.com/Network", + DisplayName: "driftctl-unittest-2", + Name: "//compute.googleapis.com/projects/driftctl-qa-1/global/networks/driftctl-unittest-2", + }, + { + AssetType: "compute.googleapis.com/Network", + DisplayName: "driftctl-unittest-3", + Name: "//compute.googleapis.com/projects/driftctl-qa-1/global/networks/driftctl-unittest-3", + }, + }, + wantErr: nil, + }, + { + test: "cannot list compute networks", + dirName: "google_compute_network_empty", + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_network", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_network", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleComputeNetworkResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err = realProvider.Init() + if err != nil { + tt.Fatal(err) + } + provider.ShouldUpdate() + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeNetworkEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resType, common2.NewGenericDetailsFetcher(resType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} + +func TestGoogleComputeInstanceGroup(t *testing.T) { + + cases := []struct { + test string + dirName string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no instance group", + dirName: "google_compute_instance_group_empty", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + }, + { + test: "multiple instance groups", + dirName: "google_compute_instance_group", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/InstanceGroup", + DisplayName: "driftctl-test-1", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-1", + Project: "cloudskiff-dev-raphael", + Location: "us-central1-a", + }, + { + AssetType: "compute.googleapis.com/InstanceGroup", + DisplayName: "driftctl-test-2", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-2", + Project: "cloudskiff-dev-raphael", + Location: "us-central1-a", + }, + }, + wantErr: nil, + }, + { + test: "cannot list instance groups", + dirName: "google_compute_instance_group_empty", + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_instance_group", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_instance_group", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleComputeInstanceGroupResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err = realProvider.Init() + if err != nil { + tt.Fatal(err) + } + provider.ShouldUpdate() + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeInstanceGroupEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(googleresource.GoogleComputeInstanceGroupResourceType, common2.NewGenericDetailsFetcher(googleresource.GoogleComputeInstanceGroupResourceType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} + +func TestGoogleComputeAddress(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute address", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute address", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address", got[0].ResourceId()) + assert.Equal(t, "google_compute_address", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address-2", got[1].ResourceId()) + assert.Equal(t, "google_compute_address", got[1].ResourceType()) + assert.Equal(t, "1.2.3.4", *got[1].Attributes().GetString("address")) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Address", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address", + }, + { + AssetType: "compute.googleapis.com/Address", + Location: "global", // Global addresses should be ignored + }, + { + AssetType: "compute.googleapis.com/Address", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address-2", + AdditionalAttributes: func() *structpb.Struct { + str, _ := structpb.NewStruct(map[string]interface{}{ + "address": "1.2.3.4", + }) + return str + }(), + }, + }, + }, + { + test: "cannot list compute address", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_address", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_address", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeAddressEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeGlobalAddress(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no resource", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "one resource returned", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", got[0].ResourceId()) + assert.Equal(t, "google_compute_global_address", got[0].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "compute.googleapis.com/GlobalAddress", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", + Resource: &assetpb.Resource{ + Data: func() *structpb.Struct { + v, err := structpb.NewStruct(map[string]interface{}{ + "name": "projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", + }) + if err != nil { + t.Fatal(err) + } + return v + }(), + }, + }, + }, + }, + { + test: "one resource without resource data", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + response: []*assetpb.Asset{ + { + AssetType: "compute.googleapis.com/GlobalAddress", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", + }, + }, + }, + { + test: "cannot list cloud functions", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_global_address", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_global_address", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeGlobalAddressEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeSubnetwork(t *testing.T) { + + cases := []struct { + test string + dirName string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no subnetwork", + dirName: "google_compute_subnetwork_empty", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + }, + { + test: "multiple subnetworks", + dirName: "google_compute_subnetwork_multiple", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Subnetwork", + DisplayName: "driftctl-unittest-1", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-1", + }, + { + AssetType: "compute.googleapis.com/Subnetwork", + DisplayName: "driftctl-unittest-2", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-2", + }, + { + AssetType: "compute.googleapis.com/Subnetwork", + DisplayName: "driftctl-unittest-3", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-3", + }, + }, + wantErr: nil, + }, + { + test: "cannot list compute subnetworks", + dirName: "google_compute_subnetwork_empty", + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_subnetwork", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_subnetwork", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleComputeSubnetworkResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + err = realProvider.Init() + if err != nil { + tt.Fatal(err) + } + provider.ShouldUpdate() + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeSubnetworkEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resType, common2.NewGenericDetailsFetcher(resType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} + +func TestGoogleComputeDisk(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute disk", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute disk", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk", got[0].ResourceId()) + assert.Equal(t, "google_compute_disk", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk-2", got[1].ResourceId()) + assert.Equal(t, "google_compute_disk", got[1].ResourceType()) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Disk", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk", + }, + { + AssetType: "compute.googleapis.com/Disk", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk-2", + }, + }, + }, + { + test: "cannot list compute disk", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_disk", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_disk", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeDiskEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeImage(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute image", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples images", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-elie/global/images/example-image", got[0].ResourceId()) + assert.Equal(t, "google_compute_image", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-elie/global/images/example-image-2", got[1].ResourceId()) + assert.Equal(t, "google_compute_image", got[1].ResourceType()) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/Image", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/images/example-image", + }, + { + AssetType: "compute.googleapis.com/Image", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/images/example-image-2", + }, + }, + }, + { + test: "cannot list images", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_image", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_image", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeImageEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeHealthCheck(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute health check", + response: []*assetpb.ResourceSearchResult{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute health checks", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-1", got[0].ResourceId()) + assert.Equal(t, "google_compute_health_check", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-2", got[1].ResourceId()) + assert.Equal(t, "google_compute_health_check", got[1].ResourceType()) + }, + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "compute.googleapis.com/HealthCheck", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-1", + }, + { + AssetType: "compute.googleapis.com/HealthCheck", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-2", + }, + }, + }, + { + test: "cannot list compute health checks", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_health_check", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_health_check", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository(terraform3.GOOGLE, providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeHealthCheckEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeNodeGroup(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute node group", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute node group", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/soletenant-group", got[0].ResourceId()) + assert.Equal(t, "google_compute_node_group", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/simple-group", got[1].ResourceId()) + assert.Equal(t, "google_compute_node_group", got[1].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "compute.googleapis.com/NodeGroup", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/soletenant-group", + }, + { + AssetType: "compute.googleapis.com/NodeGroup", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/simple-group", + }, + }, + }, + { + test: "cannot list compute node group", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_node_group", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_node_group", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository(terraform3.GOOGLE, providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeNodeGroupEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeForwardingRule(t *testing.T) { + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute forwarding rules", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple compute forwarding rules", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/foo", got[0].ResourceId()) + assert.Equal(t, "google_compute_forwarding_rule", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/bar", got[1].ResourceId()) + assert.Equal(t, "google_compute_forwarding_rule", got[1].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "compute.googleapis.com/ForwardingRule", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/foo", + }, + { + AssetType: "compute.googleapis.com/ForwardingRule", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/bar", + }, + }, + }, + { + test: "cannot list compute forwarding rules", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_forwarding_rule", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_forwarding_rule", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository(terraform3.GOOGLE, providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeForwardingRuleEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeInstanceGroupManager(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute instance group manager", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples compute instance group managers", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-abc", got[0].ResourceId()) + assert.Equal(t, "google_compute_instance_group_manager", got[0].ResourceType()) + + assert.Equal(t, "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-def", got[1].ResourceId()) + assert.Equal(t, "google_compute_instance_group_manager", got[1].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "compute.googleapis.com/InstanceGroupManager", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-abc", + }, + { + AssetType: "compute.googleapis.com/InstanceGroupManager", + Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-def", + }, + }, + }, + { + test: "cannot list compute instance group managers", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_instance_group_manager", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_instance_group_manager", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository(terraform3.GOOGLE, providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeInstanceGroupManagerEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} + +func TestGoogleComputeGlobalForwardingRule(t *testing.T) { + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no compute global forwarding rules", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiple compute global forwarding rules", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 2) + assert.Equal(t, "//projects/driftctl-qa-1/global/forwardingRules/global-rule-foo", got[0].ResourceId()) + assert.Equal(t, "google_compute_global_forwarding_rule", got[0].ResourceType()) + + assert.Equal(t, "//projects/driftctl-qa-1/global/forwardingRules/global-rule-bar", got[1].ResourceId()) + assert.Equal(t, "google_compute_global_forwarding_rule", got[1].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "compute.googleapis.com/GlobalForwardingRule", + Name: "//projects/driftctl-qa-1/global/forwardingRules/global-rule-foo", + }, + { + AssetType: "compute.googleapis.com/GlobalForwardingRule", + Name: "//projects/driftctl-qa-1/global/forwardingRules/global-rule-bar", + }, + }, + }, + { + test: "cannot list compute global forwarding rules", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_compute_global_forwarding_rule", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_compute_global_forwarding_rule", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository(terraform3.GOOGLE, providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleComputeGlobalForwardingRuleEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_network_scanner_test.go b/enumeration/remote/google_network_scanner_test.go new file mode 100644 index 00000000..2956c004 --- /dev/null +++ b/enumeration/remote/google_network_scanner_test.go @@ -0,0 +1,156 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestGoogleDNSNanagedZone(t *testing.T) { + + cases := []struct { + test string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + assertExpected func(t *testing.T, got []*resource.Resource) + }{ + { + test: "no managed zone", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "multiples managed zones", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "dns.googleapis.com/ManagedZone", + Name: "invalid ID", // Should be ignored + }, + { + AssetType: "dns.googleapis.com/ManagedZone", + DisplayName: "test-zone-0", + Name: "//dns.googleapis.com/projects/cloudskiff-dev-raphael/managedZones/123456789", + }, + { + AssetType: "dns.googleapis.com/ManagedZone", + DisplayName: "test-zone-1", + Name: "//dns.googleapis.com/projects/cloudskiff-dev-raphael/managedZones/123456789", + }, + { + AssetType: "dns.googleapis.com/ManagedZone", + DisplayName: "test-zone-2", + Name: "//dns.googleapis.com/projects/cloudskiff-dev-raphael/managedZones/123456789", + }, + }, + wantErr: nil, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 3) + + assert.Equal(t, got[0].ResourceId(), "projects/cloudskiff-dev-raphael/managedZones/test-zone-0") + assert.Equal(t, got[0].ResourceType(), googleresource.GoogleDNSManagedZoneResourceType) + + assert.Equal(t, got[1].ResourceId(), "projects/cloudskiff-dev-raphael/managedZones/test-zone-1") + assert.Equal(t, got[1].ResourceType(), googleresource.GoogleDNSManagedZoneResourceType) + + assert.Equal(t, got[2].ResourceId(), "projects/cloudskiff-dev-raphael/managedZones/test-zone-2") + assert.Equal(t, got[2].ResourceType(), googleresource.GoogleDNSManagedZoneResourceType) + }, + }, + { + test: "should return access denied error", + wantErr: nil, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + googleresource.GoogleDNSManagedZoneResourceType, + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + googleresource.GoogleDNSManagedZoneResourceType, + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google.NewGoogleDNSManagedZoneEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_project_scanner_test.go b/enumeration/remote/google_project_scanner_test.go new file mode 100644 index 00000000..c839dc94 --- /dev/null +++ b/enumeration/remote/google_project_scanner_test.go @@ -0,0 +1,140 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGoogleProjectIAMMember(t *testing.T) { + + cases := []struct { + test string + dirName string + repositoryMock func(repository *repository.MockCloudResourceManagerRepository) + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no bindings", + dirName: "google_project_member_empty", + repositoryMock: func(repository *repository.MockCloudResourceManagerRepository) { + repository.On("ListProjectsBindings").Return(map[string]map[string][]string{}, nil) + }, + wantErr: nil, + }, + { + test: "Cannot list bindings", + dirName: "google_project_member_listing_error", + repositoryMock: func(repository *repository.MockCloudResourceManagerRepository) { + repository.On("ListProjectsBindings").Return( + map[string]map[string][]string{}, + errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have project.getIamPolicy access., forbidden")) + }, + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_project_iam_member", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have project.getIamPolicy access., forbidden"), + "google_project_iam_member", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + { + test: "multiples storage buckets, multiple bindings", + dirName: "google_project_member_listing_multiple", + repositoryMock: func(repository *repository.MockCloudResourceManagerRepository) { + repository.On("ListProjectsBindings").Return(map[string]map[string][]string{ + "": { + "roles/editor": { + "user:martin.guibert@cloudskiff.com", + "serviceAccount:drifctl-admin@cloudskiff-dev-martin.iam.gserviceaccount.com", + }, + "roles/storage.admin": {"user:martin.guibert@cloudskiff.com"}, + "roles/viewer": {"serviceAccount:driftctl@cloudskiff-dev-martin.iam.gserviceaccount.com"}, + "roles/cloudasset.viewer": {"serviceAccount:driftctl@cloudskiff-dev-martin.iam.gserviceaccount.com"}, + "roles/iam.securityReviewer": {"serviceAccount:driftctl@cloudskiff-dev-martin.iam.gserviceaccount.com"}, + }, + }, nil) + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleProjectIamMemberResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + managerRepository := &repository.MockCloudResourceManagerRepository{} + if c.repositoryMock != nil { + c.repositoryMock(managerRepository) + } + + remoteLibrary.AddEnumerator(google.NewGoogleProjectIamMemberEnumerator(managerRepository, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} diff --git a/enumeration/remote/google_sql_scanner_test.go b/enumeration/remote/google_sql_scanner_test.go new file mode 100644 index 00000000..bcb05499 --- /dev/null +++ b/enumeration/remote/google_sql_scanner_test.go @@ -0,0 +1,152 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" +) + +func TestGoogleSQLDatabaseInstance(t *testing.T) { + + cases := []struct { + test string + assertExpected func(t *testing.T, got []*resource.Resource) + response []*assetpb.Asset + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no instance", + response: []*assetpb.Asset{}, + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + }, + { + test: "one resource returned", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 1) + assert.Equal(t, "instance-test", got[0].ResourceId()) + assert.Equal(t, "google_sql_database_instance", got[0].ResourceType()) + }, + response: []*assetpb.Asset{ + { + AssetType: "sqladmin.googleapis.com/Instance", + Resource: &assetpb.Resource{ + Data: func() *structpb.Struct { + v, err := structpb.NewStruct(map[string]interface{}{ + "name": "instance-test", + }) + if err != nil { + t.Fatal(err) + } + return v + }(), + }, + }, + }, + }, + { + test: "one resource without resource data", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + response: []*assetpb.Asset{ + { + AssetType: "sqladmin.googleapis.com/Instance", + }, + }, + }, + { + test: "cannot list resources", + assertExpected: func(t *testing.T, got []*resource.Resource) { + assert.Len(t, got, 0) + }, + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_sql_database_instance", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_sql_database_instance", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + }, + } + + providerVersion := "3.78.0" + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + scanOptions := ScannerOptions{} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + + repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google.NewGoogleSQLDatabaseInstanceEnumerator(repo, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + if c.assertExpected != nil { + c.assertExpected(t, got) + } + }) + } +} diff --git a/enumeration/remote/google_storage_scanner_test.go b/enumeration/remote/google_storage_scanner_test.go new file mode 100644 index 00000000..57f7a76f --- /dev/null +++ b/enumeration/remote/google_storage_scanner_test.go @@ -0,0 +1,331 @@ +package remote + +import ( + "context" + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/cache" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + google2 "github.com/snyk/driftctl/enumeration/remote/google" + repository2 "github.com/snyk/driftctl/enumeration/remote/google/repository" + terraform3 "github.com/snyk/driftctl/enumeration/terraform" + + asset "cloud.google.com/go/asset/apiv1" + "cloud.google.com/go/storage" + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/resource" + googleresource "github.com/snyk/driftctl/enumeration/resource/google" + "github.com/snyk/driftctl/mocks" + + "github.com/snyk/driftctl/test" + "github.com/snyk/driftctl/test/goldenfile" + testgoogle "github.com/snyk/driftctl/test/google" + testresource "github.com/snyk/driftctl/test/resource" + terraform2 "github.com/snyk/driftctl/test/terraform" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestGoogleStorageBucket(t *testing.T) { + + cases := []struct { + test string + dirName string + response []*assetpb.ResourceSearchResult + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no storage buckets", + dirName: "google_storage_bucket_empty", + response: []*assetpb.ResourceSearchResult{}, + wantErr: nil, + }, + { + test: "multiples storage buckets", + dirName: "google_storage_bucket", + response: []*assetpb.ResourceSearchResult{ + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "driftctl-unittest-1", + }, + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "driftctl-unittest-2", + }, + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "driftctl-unittest-3", + }, + }, + wantErr: nil, + }, + { + test: "cannot list storage buckets", + dirName: "google_storage_bucket_empty", + responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_storage_bucket", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + status.Error(codes.PermissionDenied, "The caller does not have permission"), + "google_storage_bucket", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleStorageBucketResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + var assetClient *asset.Client + if !shouldUpdate { + var err error + assetClient, err = testgoogle.NewFakeAssetServer(c.response, c.responseErr) + if err != nil { + tt.Fatal(err) + } + } + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + // Replace mock by real resources if we are in update mode + if shouldUpdate { + ctx := context.Background() + assetClient, err = asset.NewClient(ctx) + if err != nil { + tt.Fatal(err) + } + err = realProvider.Init() + if err != nil { + tt.Fatal(err) + } + provider.ShouldUpdate() + } + + repo := repository2.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) + + remoteLibrary.AddEnumerator(google2.NewGoogleStorageBucketEnumerator(repo, factory)) + remoteLibrary.AddDetailsFetcher(resType, common2.NewGenericDetailsFetcher(resType, provider, deserializer)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, err, c.wantErr) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} + +func TestGoogleStorageBucketIAMMember(t *testing.T) { + + cases := []struct { + test string + dirName string + assetRepositoryMock func(assetRepository *repository2.MockAssetRepository) + storageRepositoryMock func(storageRepository *repository2.MockStorageRepository) + responseErr error + setupAlerterMock func(alerter *mocks.AlerterInterface) + wantErr error + }{ + { + test: "no storage buckets", + dirName: "google_storage_bucket_member_empty", + assetRepositoryMock: func(assetRepository *repository2.MockAssetRepository) { + assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{}, nil) + }, + wantErr: nil, + }, + { + test: "multiples storage buckets, no bindings", + dirName: "google_storage_bucket_member_empty", + assetRepositoryMock: func(assetRepository *repository2.MockAssetRepository) { + assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{ + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "dctlgstoragebucketiambinding-1", + }, + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "dctlgstoragebucketiambinding-2", + }, + }, nil) + }, + storageRepositoryMock: func(storageRepository *repository2.MockStorageRepository) { + storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-1").Return(map[string][]string{}, nil) + storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-2").Return(map[string][]string{}, nil) + }, + wantErr: nil, + }, + { + test: "Cannot list bindings", + dirName: "google_storage_bucket_member_listing_error", + assetRepositoryMock: func(assetRepository *repository2.MockAssetRepository) { + assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{ + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "dctlgstoragebucketiambinding-1", + }, + }, nil) + }, + storageRepositoryMock: func(storageRepository *repository2.MockStorageRepository) { + storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-1").Return( + map[string][]string{}, + errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have storage.buckets.getIamPolicy access to the Google Cloud Storage bucket., forbidden")) + }, + setupAlerterMock: func(alerter *mocks.AlerterInterface) { + alerter.On( + "SendAlert", + "google_storage_bucket_iam_member", + alerts.NewRemoteAccessDeniedAlert( + common2.RemoteGoogleTerraform, + remoteerr.NewResourceListingError( + errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have storage.buckets.getIamPolicy access to the Google Cloud Storage bucket., forbidden"), + "google_storage_bucket_iam_member", + ), + alerts.EnumerationPhase, + ), + ).Once() + }, + wantErr: nil, + }, + { + test: "multiples storage buckets, multiple bindings", + dirName: "google_storage_bucket_member_listing_multiple", + assetRepositoryMock: func(assetRepository *repository2.MockAssetRepository) { + assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{ + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "dctlgstoragebucketiambinding-1", + }, + { + AssetType: "storage.googleapis.com/Bucket", + DisplayName: "dctlgstoragebucketiambinding-2", + }, + }, nil) + }, + storageRepositoryMock: func(storageRepository *repository2.MockStorageRepository) { + storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-1").Return(map[string][]string{ + "roles/storage.admin": {"user:elie.charra@cloudskiff.com"}, + "roles/storage.objectViewer": {"user:william.beuil@cloudskiff.com"}, + }, nil) + + storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-2").Return(map[string][]string{ + "roles/storage.admin": {"user:william.beuil@cloudskiff.com"}, + "roles/storage.objectViewer": {"user:elie.charra@cloudskiff.com"}, + }, nil) + }, + wantErr: nil, + }, + } + + providerVersion := "3.78.0" + resType := resource.ResourceType(googleresource.GoogleStorageBucketIamMemberResourceType) + schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) + googleresource.InitResourcesMetadata(schemaRepository) + factory := terraform3.NewTerraformResourceFactory(schemaRepository) + deserializer := resource.NewDeserializer(factory) + + for _, c := range cases { + t.Run(c.test, func(tt *testing.T) { + repositoryCache := cache.New(100) + + shouldUpdate := c.dirName == *goldenfile.Update + + scanOptions := ScannerOptions{Deep: true} + providerLibrary := terraform3.NewProviderLibrary() + remoteLibrary := common2.NewRemoteLibrary() + + // Initialize mocks + alerter := &mocks.AlerterInterface{} + if c.setupAlerterMock != nil { + c.setupAlerterMock(alerter) + } + + storageRepo := &repository2.MockStorageRepository{} + if c.storageRepositoryMock != nil { + c.storageRepositoryMock(storageRepo) + } + var storageRepository repository2.StorageRepository = storageRepo + if shouldUpdate { + storageClient, err := storage.NewClient(context.Background()) + if err != nil { + panic(err) + } + storageRepository = repository2.NewStorageRepository(storageClient, repositoryCache) + } + + assetRepo := &repository2.MockAssetRepository{} + if c.assetRepositoryMock != nil { + c.assetRepositoryMock(assetRepo) + } + var assetRepository repository2.AssetRepository = assetRepo + + realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) + if err != nil { + tt.Fatal(err) + } + provider := terraform2.NewFakeTerraformProvider(realProvider) + provider.WithResponse(c.dirName) + + remoteLibrary.AddEnumerator(google2.NewGoogleStorageBucketIamMemberEnumerator(assetRepository, storageRepository, factory)) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", mock.Anything).Return(false) + + s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) + got, err := s.Resources() + assert.Equal(tt, c.wantErr, err) + if err != nil { + return + } + alerter.AssertExpectations(tt) + testFilter.AssertExpectations(tt) + test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) + }) + } +} diff --git a/enumeration/remote/remote.go b/enumeration/remote/remote.go new file mode 100644 index 00000000..eb2f3816 --- /dev/null +++ b/enumeration/remote/remote.go @@ -0,0 +1,56 @@ +package remote + +import ( + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/aws" + "github.com/snyk/driftctl/enumeration/remote/azurerm" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/remote/github" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/terraform" +) + +var supportedRemotes = []string{ + common2.RemoteAWSTerraform, + common2.RemoteGithubTerraform, + common2.RemoteGoogleTerraform, + common2.RemoteAzureTerraform, +} + +func IsSupported(remote string) bool { + for _, r := range supportedRemotes { + if r == remote { + return true + } + } + return false +} + +func Activate(remote, version string, alerter *alerter.Alerter, + providerLibrary *terraform.ProviderLibrary, + remoteLibrary *common2.RemoteLibrary, + progress enumeration.ProgressCounter, + resourceSchemaRepository *resource.SchemaRepository, + factory resource.ResourceFactory, + configDir string) error { + switch remote { + case common2.RemoteAWSTerraform: + return aws.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) + case common2.RemoteGithubTerraform: + return github.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) + case common2.RemoteGoogleTerraform: + return google.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) + case common2.RemoteAzureTerraform: + return azurerm.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) + + default: + return errors.Errorf("unsupported remote '%s'", remote) + } +} + +func GetSupportedRemotes() []string { + return supportedRemotes +} diff --git a/enumeration/remote/resource_enumeration_error_handler.go b/enumeration/remote/resource_enumeration_error_handler.go new file mode 100644 index 00000000..c2d09274 --- /dev/null +++ b/enumeration/remote/resource_enumeration_error_handler.go @@ -0,0 +1,116 @@ +package remote + +import ( + "strings" + + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerror "github.com/snyk/driftctl/enumeration/remote/error" + + "github.com/aws/aws-sdk-go/aws/awserr" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func HandleResourceEnumerationError(err error, alerter alerter.AlerterInterface) error { + listError, ok := err.(*remoteerror.ResourceScanningError) + if !ok { + return err + } + + rootCause := listError.RootCause() + + // We cannot use the status.FromError() method because AWS errors are not well-formed. + // Indeed, they compose the error interface without implementing the Error() method and thus triggering a nil panic + // when returning an unknown error from status.FromError() + // As a workaround we duplicated the logic from status.FromError here + if _, ok := rootCause.(interface{ GRPCStatus() *status.Status }); ok { + return handleGoogleEnumerationError(alerter, listError, status.Convert(rootCause)) + } + + // at least for storage api google sdk does not return grpc error so we parse the error message. + if shouldHandleGoogleForbiddenError(listError) { + alerts.SendEnumerationAlert(common.RemoteGoogleTerraform, alerter, listError) + return nil + } + + reqerr, ok := rootCause.(awserr.RequestFailure) + if ok { + return handleAWSError(alerter, listError, reqerr) + } + + // This handles access denied errors like the following: + // aws_s3_bucket_policy: AccessDenied: Error listing bucket policy + if strings.Contains(rootCause.Error(), "AccessDenied") { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, alerter, listError) + return nil + } + + if strings.HasPrefix( + rootCause.Error(), + "Your token has not been granted the required scopes to execute this query.", + ) { + alerts.SendEnumerationAlert(common.RemoteGithubTerraform, alerter, listError) + return nil + } + + return err +} + +func HandleResourceDetailsFetchingError(err error, alerter alerter.AlerterInterface) error { + listError, ok := err.(*remoteerror.ResourceScanningError) + if !ok { + return err + } + + rootCause := listError.RootCause() + + if shouldHandleGoogleForbiddenError(listError) { + alerts.SendDetailsFetchingAlert(common.RemoteGoogleTerraform, alerter, listError) + return nil + } + + // This handles access denied errors like the following: + // iam_role_policy: error reading IAM Role Policy (): AccessDenied: User: ... + if strings.HasPrefix(rootCause.Error(), "AccessDeniedException") || + strings.Contains(rootCause.Error(), "AccessDenied") || + strings.Contains(rootCause.Error(), "AuthorizationError") { + alerts.SendDetailsFetchingAlert(common.RemoteAWSTerraform, alerter, listError) + return nil + } + + return err +} + +func handleAWSError(alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError, reqerr awserr.RequestFailure) error { + if reqerr.StatusCode() == 403 || (reqerr.StatusCode() == 400 && strings.Contains(reqerr.Code(), "AccessDenied")) { + alerts.SendEnumerationAlert(common.RemoteAWSTerraform, alerter, listError) + return nil + } + + return reqerr +} + +func handleGoogleEnumerationError(alerter alerter.AlerterInterface, err *remoteerror.ResourceScanningError, st *status.Status) error { + if st.Code() == codes.PermissionDenied { + alerts.SendEnumerationAlert(common.RemoteGoogleTerraform, alerter, err) + return nil + } + return err +} + +func shouldHandleGoogleForbiddenError(err *remoteerror.ResourceScanningError) bool { + errMsg := err.RootCause().Error() + + // Check if this is a Google related error + if !strings.Contains(errMsg, "googleapi") { + return false + } + + if strings.Contains(errMsg, "Error 403") { + return true + } + + return false +} diff --git a/enumeration/remote/resource_enumeration_error_handler_test.go b/enumeration/remote/resource_enumeration_error_handler_test.go new file mode 100644 index 00000000..fcf3162f --- /dev/null +++ b/enumeration/remote/resource_enumeration_error_handler_test.go @@ -0,0 +1,375 @@ +package remote + +import ( + "errors" + "testing" + + alerter2 "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + + resourcegithub "github.com/snyk/driftctl/enumeration/resource/github" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws/awserr" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" +) + +func TestHandleAwsEnumerationErrors(t *testing.T) { + + tests := []struct { + name string + err error + wantAlerts alerter2.Alerts + wantErr bool + }{ + { + name: "Handled error 403", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("", "", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), + wantAlerts: alerter2.Alerts{"aws_vpc": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("", "", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.EnumerationPhase)}}, + wantErr: false, + }, + { + name: "Handled error AccessDenied", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, ""), resourceaws.AwsDynamodbTableResourceType), + wantAlerts: alerter2.Alerts{"aws_dynamodb_table": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, ""), "aws_dynamodb_table", "aws_dynamodb_table"), alerts.EnumerationPhase)}}, + wantErr: false, + }, + { + name: "Not Handled error code", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("", "", errors.New("")), 404, ""), resourceaws.AwsVpcResourceType), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + { + name: "Not Handled error type", + err: errors.New("error"), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + { + name: "Not Handled root error type", + err: remoteerr.NewResourceListingError(errors.New("error"), resourceaws.AwsVpcResourceType), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + { + name: "Handle AccessDenied error", + err: remoteerr.NewResourceListingError(errors.New("an error occured: AccessDenied: 403"), resourceaws.AwsVpcResourceType), + wantAlerts: alerter2.Alerts{"aws_vpc": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("an error occured: AccessDenied: 403"), "aws_vpc", "aws_vpc"), alerts.EnumerationPhase)}}, + wantErr: false, + }, + { + name: "Access denied error on a single resource", + err: remoteerr.NewResourceScanningError(errors.New("Error: AccessDenied: 403 ..."), resourceaws.AwsS3BucketResourceType, "my-bucket"), + wantAlerts: alerter2.Alerts{"aws_s3_bucket.my-bucket": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Error: AccessDenied: 403 ..."), "aws_s3_bucket.my-bucket", "aws_s3_bucket"), alerts.EnumerationPhase)}}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + alertr := alerter2.NewAlerter() + gotErr := HandleResourceEnumerationError(tt.err, alertr) + assert.Equal(t, tt.wantErr, gotErr != nil) + + retrieve := alertr.Retrieve() + assert.Equal(t, tt.wantAlerts, retrieve) + + }) + } +} + +func TestHandleGithubEnumerationErrors(t *testing.T) { + + tests := []struct { + name string + err error + wantAlerts alerter2.Alerts + wantErr bool + }{ + { + name: "Handled graphql error", + err: remoteerr.NewResourceListingError(errors.New("Your token has not been granted the required scopes to execute this query."), resourcegithub.GithubTeamResourceType), + wantAlerts: alerter2.Alerts{"github_team": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), "github_team", "github_team"), alerts.EnumerationPhase)}}, + wantErr: false, + }, + { + name: "Not handled graphql error", + err: remoteerr.NewResourceListingError(errors.New("This is a not handler graphql error"), resourcegithub.GithubTeamResourceType), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + { + name: "Not Handled error type", + err: errors.New("error"), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + alertr := alerter2.NewAlerter() + gotErr := HandleResourceEnumerationError(tt.err, alertr) + assert.Equal(t, tt.wantErr, gotErr != nil) + + retrieve := alertr.Retrieve() + assert.Equal(t, tt.wantAlerts, retrieve) + + }) + } +} + +func TestHandleGoogleEnumerationErrors(t *testing.T) { + tests := []struct { + name string + err error + wantAlerts alerter2.Alerts + wantErr bool + }{ + { + name: "Handled 403 error", + err: remoteerr.NewResourceListingError(status.Error(codes.PermissionDenied, "useless message"), "google_type"), + wantAlerts: alerter2.Alerts{"google_type": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteGoogleTerraform, remoteerr.NewResourceListingErrorWithType(status.Error(codes.PermissionDenied, "useless message"), "google_type", "google_type"), alerts.EnumerationPhase)}}, + wantErr: false, + }, + { + name: "Not handled non 403 error", + err: status.Error(codes.Unknown, ""), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + { + name: "Not Handled error type", + err: errors.New("error"), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + alertr := alerter2.NewAlerter() + gotErr := HandleResourceEnumerationError(tt.err, alertr) + assert.Equal(t, tt.wantErr, gotErr != nil) + + retrieve := alertr.Retrieve() + assert.Equal(t, tt.wantAlerts, retrieve) + + }) + } +} + +func TestHandleAwsDetailsFetchingErrors(t *testing.T) { + + tests := []struct { + name string + err error + wantAlerts alerter2.Alerts + wantErr bool + }{ + { + name: "Handle AccessDeniedException error", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "test", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), + wantAlerts: alerter2.Alerts{"aws_vpc": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "test", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.DetailsFetchingPhase)}}, + wantErr: false, + }, + { + name: "Handle AccessDenied error", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("test", "error: AccessDenied", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), + wantAlerts: alerter2.Alerts{"aws_vpc": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("test", "error: AccessDenied", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.DetailsFetchingPhase)}}, + wantErr: false, + }, + { + name: "Handle AuthorizationError error", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("test", "error: AuthorizationError", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), + wantAlerts: alerter2.Alerts{"aws_vpc": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("test", "error: AuthorizationError", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.DetailsFetchingPhase)}}, + wantErr: false, + }, + { + name: "Unhandled error", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("test", "error: dummy error", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), + wantAlerts: alerter2.Alerts{}, + wantErr: true, + }, + { + name: "Not Handled error type", + err: errors.New("error"), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + alertr := alerter2.NewAlerter() + gotErr := HandleResourceDetailsFetchingError(tt.err, alertr) + assert.Equal(t, tt.wantErr, gotErr != nil) + + retrieve := alertr.Retrieve() + assert.Equal(t, tt.wantAlerts, retrieve) + + }) + } +} + +func TestHandleGoogleDetailsFetchingErrors(t *testing.T) { + + tests := []struct { + name string + err error + wantAlerts alerter2.Alerts + wantErr bool + }{ + { + name: "Handle 403 error", + err: remoteerr.NewResourceScanningError( + errors.New("Error when reading or editing Storage Bucket \"driftctl-unittest-1\": googleapi: Error 403: driftctl@elie-dev.iam.gserviceaccount.com does not have storage.buckets.get access to the Google Cloud Storage bucket., forbidden"), + "google_type", + "resource_id", + ), + wantAlerts: alerter2.Alerts{"google_type.resource_id": []alerter2.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteGoogleTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Error when reading or editing Storage Bucket \"driftctl-unittest-1\": googleapi: Error 403: driftctl@elie-dev.iam.gserviceaccount.com does not have storage.buckets.get access to the Google Cloud Storage bucket., forbidden"), "google_type.resource_id", "google_type"), alerts.DetailsFetchingPhase)}}, + wantErr: false, + }, + { + name: "do not handle google unrelated error", + err: remoteerr.NewResourceScanningError( + errors.New("this string does not contains g o o g l e a p i string and thus should not be matched"), + "google_type", + "resource_id", + ), wantAlerts: alerter2.Alerts{}, + wantErr: true, + }, + { + name: "do not handle google error other than 403", + err: remoteerr.NewResourceScanningError( + errors.New("Error when reading or editing Storage Bucket \"driftctl-unittest-1\": googleapi: Error 404: not found"), + "google_type", + "resource_id", + ), wantAlerts: alerter2.Alerts{}, + wantErr: true, + }, + { + name: "Not Handled error type", + err: errors.New("error"), + wantAlerts: map[string][]alerter2.Alert{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + alertr := alerter2.NewAlerter() + gotErr := HandleResourceDetailsFetchingError(tt.err, alertr) + assert.Equal(t, tt.wantErr, gotErr != nil) + + retrieve := alertr.Retrieve() + assert.Equal(t, tt.wantAlerts, retrieve) + + }) + } +} + +func TestEnumerationAccessDeniedAlert_GetProviderMessage(t *testing.T) { + tests := []struct { + name string + provider string + want string + }{ + { + name: "test for unsupported provider", + provider: "foobar", + want: "", + }, + { + name: "test for AWS", + provider: common.RemoteAWSTerraform, + want: "It seems that we got access denied exceptions while listing resources.\nThe latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy", + }, + { + name: "test for github", + provider: common.RemoteGithubTerraform, + want: "It seems that we got access denied exceptions while listing resources.\nPlease be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := alerts.NewRemoteAccessDeniedAlert(tt.provider, remoteerr.NewResourceListingErrorWithType(errors.New("dummy error"), "supplier_type", "listed_type_error"), alerts.EnumerationPhase) + if got := e.GetProviderMessage(); got != tt.want { + t.Errorf("GetProviderMessage() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDetailsFetchingAccessDeniedAlert_GetProviderMessage(t *testing.T) { + tests := []struct { + name string + provider string + want string + }{ + { + name: "test for unsupported provider", + provider: "foobar", + want: "", + }, + { + name: "test for AWS", + provider: common.RemoteAWSTerraform, + want: "It seems that we got access denied exceptions while reading details of resources.\nThe latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy", + }, + { + name: "test for github", + provider: common.RemoteGithubTerraform, + want: "It seems that we got access denied exceptions while reading details of resources.\nPlease be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy", + }, + { + name: "test for google", + provider: common.RemoteGoogleTerraform, + want: "It seems that we got access denied exceptions while reading details of resources.\nPlease ensure that you have configured the required roles, please check our documentation at https://docs.driftctl.com/google/policy", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := alerts.NewRemoteAccessDeniedAlert(tt.provider, remoteerr.NewResourceListingErrorWithType(errors.New("dummy error"), "supplier_type", "listed_type_error"), alerts.DetailsFetchingPhase) + if got := e.GetProviderMessage(); got != tt.want { + t.Errorf("GetProviderMessage() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestResourceScanningErrorMethods(t *testing.T) { + + tests := []struct { + name string + err *remoteerr.ResourceScanningError + expectedError string + expectedResourceType string + }{ + { + name: "Handled error AccessDenied", + err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, ""), resourceaws.AwsDynamodbTableResourceType), + expectedError: "error scanning resource type aws_dynamodb_table: AccessDeniedException: \n\tstatus code: 403, request id: \ncaused by: ", + expectedResourceType: resourceaws.AwsDynamodbTableResourceType, + }, + { + name: "Handle AccessDenied error", + err: remoteerr.NewResourceListingError(errors.New("an error occured: AccessDenied: 403"), resourceaws.AwsVpcResourceType), + expectedError: "error scanning resource type aws_vpc: an error occured: AccessDenied: 403", + expectedResourceType: resourceaws.AwsVpcResourceType, + }, + { + name: "Access denied error on a single resource", + err: remoteerr.NewResourceScanningError(errors.New("Error: AccessDenied: 403 ..."), resourceaws.AwsS3BucketResourceType, "my-bucket"), + expectedError: "error scanning resource aws_s3_bucket.my-bucket: Error: AccessDenied: 403 ...", + expectedResourceType: resourceaws.AwsS3BucketResourceType, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expectedError, tt.err.Error()) + assert.Equal(t, tt.expectedResourceType, tt.err.ResourceType()) + }) + } +} diff --git a/enumeration/remote/scanner.go b/enumeration/remote/scanner.go new file mode 100644 index 00000000..5d52a913 --- /dev/null +++ b/enumeration/remote/scanner.go @@ -0,0 +1,136 @@ +package remote + +import ( + "context" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/parallel" + "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/resource" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type ScannerOptions struct { + Deep bool +} + +type Scanner struct { + enumeratorRunner *parallel.ParallelRunner + detailsFetcherRunner *parallel.ParallelRunner + remoteLibrary *common.RemoteLibrary + alerter alerter.AlerterInterface + options ScannerOptions + filter enumeration.Filter +} + +func NewScanner(remoteLibrary *common.RemoteLibrary, alerter alerter.AlerterInterface, options ScannerOptions, filter enumeration.Filter) *Scanner { + return &Scanner{ + enumeratorRunner: parallel.NewParallelRunner(context.TODO(), 10), + detailsFetcherRunner: parallel.NewParallelRunner(context.TODO(), 10), + remoteLibrary: remoteLibrary, + alerter: alerter, + options: options, + filter: filter, + } +} + +func (s *Scanner) retrieveRunnerResults(runner *parallel.ParallelRunner) ([]*resource.Resource, error) { + results := make([]*resource.Resource, 0) +loop: + for { + select { + case resources, ok := <-runner.Read(): + if !ok || resources == nil { + break loop + } + + for _, res := range resources.([]*resource.Resource) { + if res != nil { + results = append(results, res) + } + } + case <-runner.DoneChan(): + break loop + } + } + return results, runner.Err() +} + +func (s *Scanner) scan() ([]*resource.Resource, error) { + for _, enumerator := range s.remoteLibrary.Enumerators() { + if s.filter.IsTypeIgnored(enumerator.SupportedType()) { + logrus.WithFields(logrus.Fields{ + "type": enumerator.SupportedType(), + }).Debug("Ignored enumeration of resources since it is ignored in filter") + continue + } + enumerator := enumerator + s.enumeratorRunner.Run(func() (interface{}, error) { + resources, err := enumerator.Enumerate() + if err != nil { + err := HandleResourceEnumerationError(err, s.alerter) + if err == nil { + return []*resource.Resource{}, nil + } + return nil, err + } + for _, res := range resources { + if res == nil { + continue + } + logrus.WithFields(logrus.Fields{ + "id": res.ResourceId(), + "type": res.ResourceType(), + }).Debug("Found cloud resource") + } + return resources, nil + }) + } + + enumerationResult, err := s.retrieveRunnerResults(s.enumeratorRunner) + if err != nil { + return nil, err + } + + if !s.options.Deep { + return enumerationResult, nil + } + + for _, res := range enumerationResult { + res := res + s.detailsFetcherRunner.Run(func() (interface{}, error) { + fetcher := s.remoteLibrary.GetDetailsFetcher(resource.ResourceType(res.ResourceType())) + if fetcher == nil { + return []*resource.Resource{res}, nil + } + + resourceWithDetails, err := fetcher.ReadDetails(res) + if err != nil { + if err := HandleResourceDetailsFetchingError(err, s.alerter); err != nil { + return nil, err + } + return []*resource.Resource{}, nil + } + return []*resource.Resource{resourceWithDetails}, nil + }) + } + + return s.retrieveRunnerResults(s.detailsFetcherRunner) +} + +func (s *Scanner) Resources() ([]*resource.Resource, error) { + resources, err := s.scan() + if err != nil { + return nil, err + } + return resources, err +} + +func (s *Scanner) Stop() { + logrus.Debug("Stopping scanner") + s.enumeratorRunner.Stop(errors.New("interrupted")) + s.detailsFetcherRunner.Stop(errors.New("interrupted")) +} diff --git a/enumeration/remote/scanner_test.go b/enumeration/remote/scanner_test.go new file mode 100644 index 00000000..5e14f1da --- /dev/null +++ b/enumeration/remote/scanner_test.go @@ -0,0 +1,33 @@ +package remote + +import ( + "testing" + + "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/alerter" + common2 "github.com/snyk/driftctl/enumeration/remote/common" + + "github.com/snyk/driftctl/enumeration/resource" + + "github.com/stretchr/testify/assert" +) + +func TestScannerShouldIgnoreType(t *testing.T) { + + // Initialize mocks + alerter := alerter.NewAlerter() + fakeEnumerator := &common2.MockEnumerator{} + fakeEnumerator.On("SupportedType").Return(resource.ResourceType("FakeType")) + fakeEnumerator.AssertNotCalled(t, "Enumerate") + + remoteLibrary := common2.NewRemoteLibrary() + remoteLibrary.AddEnumerator(fakeEnumerator) + + testFilter := &enumeration.MockFilter{} + testFilter.On("IsTypeIgnored", resource.ResourceType("FakeType")).Return(true) + + s := NewScanner(remoteLibrary, alerter, ScannerOptions{}, testFilter) + _, err := s.Resources() + assert.Nil(t, err) + fakeEnumerator.AssertExpectations(t) +} diff --git a/enumeration/remote/terraform/provider.go b/enumeration/remote/terraform/provider.go new file mode 100644 index 00000000..28c913f7 --- /dev/null +++ b/enumeration/remote/terraform/provider.go @@ -0,0 +1,224 @@ +package terraform + +import ( + "context" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/eapache/go-resiliency/retrier" + "github.com/hashicorp/terraform/plugin" + "github.com/hashicorp/terraform/plugin/discovery" + "github.com/hashicorp/terraform/providers" + "github.com/hashicorp/terraform/terraform" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + progress2 "github.com/snyk/driftctl/enumeration" + "github.com/snyk/driftctl/enumeration/parallel" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + "github.com/zclconf/go-cty/cty" + "github.com/zclconf/go-cty/cty/gocty" +) + +const EXIT_ERROR = 3 + +// "alias" in these struct are a way to namespace gRPC clients. +// For example, if we need to read S3 bucket from multiple AWS region, +// we'll have an alias per region, and the alias IS the region itself. +// So we can query resources using a specific custom provider configuration +type TerraformProviderConfig struct { + Name string + DefaultAlias string + GetProviderConfig func(alias string) interface{} +} + +type TerraformProvider struct { + lock sync.Mutex + providerInstaller *terraform2.ProviderInstaller + grpcProviders map[string]*plugin.GRPCProvider + schemas map[string]providers.Schema + Config TerraformProviderConfig + runner *parallel.ParallelRunner + progress progress2.ProgressCounter +} + +func NewTerraformProvider(installer *terraform2.ProviderInstaller, config TerraformProviderConfig, progress progress2.ProgressCounter) (*TerraformProvider, error) { + p := TerraformProvider{ + providerInstaller: installer, + runner: parallel.NewParallelRunner(context.TODO(), 10), + grpcProviders: make(map[string]*plugin.GRPCProvider), + Config: config, + progress: progress, + } + return &p, nil +} + +func (p *TerraformProvider) Init() error { + stopCh := make(chan bool) + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + select { + case <-c: + logrus.Warn("Detected interrupt during terraform provider configuration, cleanup ...") + p.Cleanup() + os.Exit(EXIT_ERROR) + case <-stopCh: + return + } + }() + defer func() { + stopCh <- true + }() + err := p.configure(p.Config.DefaultAlias) + if err != nil { + return err + } + return nil +} + +func (p *TerraformProvider) Schema() map[string]providers.Schema { + return p.schemas +} + +func (p *TerraformProvider) Runner() *parallel.ParallelRunner { + return p.runner +} + +func (p *TerraformProvider) configure(alias string) error { + providerPath, err := p.providerInstaller.Install() + if err != nil { + return err + } + + if p.grpcProviders[alias] == nil { + logrus.WithFields(logrus.Fields{ + "alias": alias, + }).Debug("Starting gRPC client") + GRPCProvider, err := terraform2.NewGRPCProvider(discovery.PluginMeta{ + Path: providerPath, + }) + + if err != nil { + return err + } + p.grpcProviders[alias] = GRPCProvider + } + + schema := p.grpcProviders[alias].GetSchema() + if p.schemas == nil { + p.schemas = schema.ResourceTypes + } + + // This value is optional. It'll be overridden by the provider config. + config := cty.NullVal(cty.DynamicPseudoType) + + if p.Config.GetProviderConfig != nil { + configType := schema.Provider.Block.ImpliedType() + config, err = gocty.ToCtyValue(p.Config.GetProviderConfig(alias), configType) + if err != nil { + return err + } + } + + resp := p.grpcProviders[alias].Configure(providers.ConfigureRequest{ + Config: config, + }) + if resp.Diagnostics.HasErrors() { + return resp.Diagnostics.Err() + } + + logrus.WithFields(logrus.Fields{ + "alias": alias, + }).Debug("New gRPC client started") + + logrus.WithFields(logrus.Fields{ + "name": p.Config.Name, + "alias": alias, + }).Debug("Terraform provider initialized") + + return nil +} + +func (p *TerraformProvider) ReadResource(args terraform2.ReadResourceArgs) (*cty.Value, error) { + + logrus.WithFields(logrus.Fields{ + "id": args.ID, + "type": args.Ty, + "attrs": args.Attributes, + }).Debugf("Reading cloud resource") + + typ := string(args.Ty) + state := &terraform.InstanceState{ + ID: args.ID, + Attributes: map[string]string{}, + } + + alias := p.Config.DefaultAlias + if args.Attributes["alias"] != "" { + alias = args.Attributes["alias"] + delete(args.Attributes, "alias") + } + + p.lock.Lock() + if p.grpcProviders[alias] == nil { + err := p.configure(alias) + if err != nil { + return nil, err + } + } + p.lock.Unlock() + + if args.Attributes != nil && len(args.Attributes) > 0 { + // call to the provider sometimes add and delete field to their attribute this may broke caller so we deep copy attributes + state.Attributes = make(map[string]string, len(args.Attributes)) + for k, v := range args.Attributes { + state.Attributes[k] = v + } + } + + impliedType := p.schemas[typ].Block.ImpliedType() + + priorState, err := state.AttrsAsObjectValue(impliedType) + if err != nil { + return nil, err + } + + var newState cty.Value + r := retrier.New(retrier.ConstantBackoff(3, 100*time.Millisecond), nil) + + err = r.Run(func() error { + resp := p.grpcProviders[alias].ReadResource(providers.ReadResourceRequest{ + TypeName: typ, + PriorState: priorState, + Private: []byte{}, + ProviderMeta: cty.NullVal(cty.DynamicPseudoType), + }) + if resp.Diagnostics.HasErrors() { + return resp.Diagnostics.Err() + } + nonFatalErr := resp.Diagnostics.NonFatalErr() + if resp.NewState.IsNull() && nonFatalErr != nil { + return errors.Errorf("state returned by ReadResource is nil: %+v", nonFatalErr) + } + newState = resp.NewState + return nil + }) + + if err != nil { + return nil, err + } + p.progress.Inc() + return &newState, nil +} + +func (p *TerraformProvider) Cleanup() { + for alias, client := range p.grpcProviders { + logrus.WithFields(logrus.Fields{ + "alias": alias, + }).Debug("Closing gRPC client") + client.Close() + } +} diff --git a/pkg/remote/test/aws_appautoscaling_policy_single/aws_appautoscaling_policy-DynamoDBReadCapacityUtilization_table_GameScores-DynamoDBReadCapacityUtilization_table_GameScores-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json b/enumeration/remote/test/aws_appautoscaling_policy_single/aws_appautoscaling_policy-DynamoDBReadCapacityUtilization_table_GameScores-DynamoDBReadCapacityUtilization_table_GameScores-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json similarity index 100% rename from pkg/remote/test/aws_appautoscaling_policy_single/aws_appautoscaling_policy-DynamoDBReadCapacityUtilization_table_GameScores-DynamoDBReadCapacityUtilization_table_GameScores-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json rename to enumeration/remote/test/aws_appautoscaling_policy_single/aws_appautoscaling_policy-DynamoDBReadCapacityUtilization_table_GameScores-DynamoDBReadCapacityUtilization_table_GameScores-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json diff --git a/pkg/remote/test/aws_appautoscaling_policy_single/results.golden.json b/enumeration/remote/test/aws_appautoscaling_policy_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_appautoscaling_policy_single/results.golden.json rename to enumeration/remote/test/aws_appautoscaling_policy_single/results.golden.json diff --git a/pkg/remote/test/aws_appautoscaling_target_single/aws_appautoscaling_target-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json b/enumeration/remote/test/aws_appautoscaling_target_single/aws_appautoscaling_target-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json similarity index 100% rename from pkg/remote/test/aws_appautoscaling_target_single/aws_appautoscaling_target-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json rename to enumeration/remote/test/aws_appautoscaling_target_single/aws_appautoscaling_target-table_GameScores-dynamodb_table_ReadCapacityUnits-dynamodb.res.golden.json diff --git a/pkg/remote/test/aws_appautoscaling_target_single/results.golden.json b/enumeration/remote/test/aws_appautoscaling_target_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_appautoscaling_target_single/results.golden.json rename to enumeration/remote/test/aws_appautoscaling_target_single/results.golden.json diff --git a/pkg/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_bar-stack_c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab.res.golden.json b/enumeration/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_bar-stack_c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab.res.golden.json similarity index 100% rename from pkg/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_bar-stack_c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab.res.golden.json rename to enumeration/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_bar-stack_c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab.res.golden.json diff --git a/pkg/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_foo-stack_c7aa0ab0-0f21-11ec-ba25-129d8c0b3757.res.golden.json b/enumeration/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_foo-stack_c7aa0ab0-0f21-11ec-ba25-129d8c0b3757.res.golden.json similarity index 100% rename from pkg/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_foo-stack_c7aa0ab0-0f21-11ec-ba25-129d8c0b3757.res.golden.json rename to enumeration/remote/test/aws_cloudformation_stack_multiple/aws_cloudformation_stack-arn_aws_cloudformation_us-east-1_047081014315_stack_foo-stack_c7aa0ab0-0f21-11ec-ba25-129d8c0b3757.res.golden.json diff --git a/pkg/remote/test/aws_cloudformation_stack_multiple/iam.yml b/enumeration/remote/test/aws_cloudformation_stack_multiple/iam.yml similarity index 100% rename from pkg/remote/test/aws_cloudformation_stack_multiple/iam.yml rename to enumeration/remote/test/aws_cloudformation_stack_multiple/iam.yml diff --git a/enumeration/remote/test/aws_cloudformation_stack_multiple/results.golden.json b/enumeration/remote/test/aws_cloudformation_stack_multiple/results.golden.json new file mode 100755 index 00000000..7b25d795 --- /dev/null +++ b/enumeration/remote/test/aws_cloudformation_stack_multiple/results.golden.json @@ -0,0 +1,42 @@ +[ + { + "capabilities": null, + "disable_rollback": false, + "iam_role_arn": "", + "id": "arn:aws:cloudformation:us-east-1:047081014315:stack/foo-stack/c7aa0ab0-0f21-11ec-ba25-129d8c0b3757", + "name": "foo-stack", + "notification_arns": null, + "on_failure": null, + "outputs": null, + "parameters": { + "VPCCidr": "10.0.0.0/16" + }, + "policy_body": null, + "policy_url": null, + "tags": null, + "template_body": "{\"Parameters\":{\"VPCCidr\":{\"Default\":\"10.0.0.0/16\",\"Description\":\"Enter the CIDR block for the VPC. Default is 10.0.0.0/16.\",\"Type\":\"String\"}},\"Resources\":{\"myVpc\":{\"Properties\":{\"CidrBlock\":{\"Ref\":\"VPCCidr\"},\"Tags\":[{\"Key\":\"Name\",\"Value\":\"Primary_CF_VPC\"}]},\"Type\":\"AWS::EC2::VPC\"}}}", + "template_url": null, + "timeout_in_minutes": null, + "timeouts": {} + }, + { + "capabilities": [ + "CAPABILITY_NAMED_IAM" + ], + "disable_rollback": false, + "iam_role_arn": "", + "id": "arn:aws:cloudformation:us-east-1:047081014315:stack/bar-stack/c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab", + "name": "bar-stack", + "notification_arns": null, + "on_failure": null, + "outputs": null, + "parameters": null, + "policy_body": null, + "policy_url": null, + "tags": null, + "template_body": "Resources:\n myUser:\n Type: AWS::IAM::User\n Properties:\n UserName: \"bar_cfn\"\n", + "template_url": null, + "timeout_in_minutes": null, + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/aws_cloudformation_stack_multiple/terraform.tf b/enumeration/remote/test/aws_cloudformation_stack_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_cloudformation_stack_multiple/terraform.tf rename to enumeration/remote/test/aws_cloudformation_stack_multiple/terraform.tf diff --git a/pkg/remote/test/aws_cloudfront_distribution_single/aws_cloudfront_distribution-E1M9CNS0XSHI19.res.golden.json b/enumeration/remote/test/aws_cloudfront_distribution_single/aws_cloudfront_distribution-E1M9CNS0XSHI19.res.golden.json similarity index 100% rename from pkg/remote/test/aws_cloudfront_distribution_single/aws_cloudfront_distribution-E1M9CNS0XSHI19.res.golden.json rename to enumeration/remote/test/aws_cloudfront_distribution_single/aws_cloudfront_distribution-E1M9CNS0XSHI19.res.golden.json diff --git a/pkg/remote/test/aws_cloudfront_distribution_single/results.golden.json b/enumeration/remote/test/aws_cloudfront_distribution_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_cloudfront_distribution_single/results.golden.json rename to enumeration/remote/test/aws_cloudfront_distribution_single/results.golden.json diff --git a/pkg/remote/test/aws_cloudfront_distribution_single/terraform.tf b/enumeration/remote/test/aws_cloudfront_distribution_single/terraform.tf similarity index 100% rename from pkg/remote/test/aws_cloudfront_distribution_single/terraform.tf rename to enumeration/remote/test/aws_cloudfront_distribution_single/terraform.tf diff --git a/pkg/remote/test/aws_default_vpc/aws_default_vpc-vpc-a8c5d4c1.res.golden.json b/enumeration/remote/test/aws_default_vpc/aws_default_vpc-vpc-a8c5d4c1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_default_vpc/aws_default_vpc-vpc-a8c5d4c1.res.golden.json rename to enumeration/remote/test/aws_default_vpc/aws_default_vpc-vpc-a8c5d4c1.res.golden.json diff --git a/pkg/remote/test/aws_default_vpc/results.golden.json b/enumeration/remote/test/aws_default_vpc/results.golden.json similarity index 100% rename from pkg/remote/test/aws_default_vpc/results.golden.json rename to enumeration/remote/test/aws_default_vpc/results.golden.json diff --git a/pkg/remote/test/aws_default_vpc/terraform.tf b/enumeration/remote/test/aws_default_vpc/terraform.tf similarity index 100% rename from pkg/remote/test/aws_default_vpc/terraform.tf rename to enumeration/remote/test/aws_default_vpc/terraform.tf diff --git a/pkg/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-GameScores-GameScores.res.golden.json b/enumeration/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-GameScores-GameScores.res.golden.json similarity index 100% rename from pkg/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-GameScores-GameScores.res.golden.json rename to enumeration/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-GameScores-GameScores.res.golden.json diff --git a/pkg/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-example-example.res.golden.json b/enumeration/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-example-example.res.golden.json similarity index 100% rename from pkg/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-example-example.res.golden.json rename to enumeration/remote/test/aws_dynamodb_table_multiple/aws_dynamodb_table-example-example.res.golden.json diff --git a/pkg/remote/test/aws_dynamodb_table_multiple/results.golden.json b/enumeration/remote/test/aws_dynamodb_table_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_dynamodb_table_multiple/results.golden.json rename to enumeration/remote/test/aws_dynamodb_table_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ebs_encryption_by_default_list/aws_ebs_encryption_by_default-ebs_encryption_default.res.golden.json b/enumeration/remote/test/aws_ebs_encryption_by_default_list/aws_ebs_encryption_by_default-ebs_encryption_default.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ebs_encryption_by_default_list/aws_ebs_encryption_by_default-ebs_encryption_default.res.golden.json rename to enumeration/remote/test/aws_ebs_encryption_by_default_list/aws_ebs_encryption_by_default-ebs_encryption_default.res.golden.json diff --git a/pkg/remote/test/aws_ebs_encryption_by_default_list/results.golden.json b/enumeration/remote/test/aws_ebs_encryption_by_default_list/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ebs_encryption_by_default_list/results.golden.json rename to enumeration/remote/test/aws_ebs_encryption_by_default_list/results.golden.json diff --git a/pkg/remote/test/aws_ec2_ami_multiple/aws_ami-ami-025962fd8b456731f.res.golden.json b/enumeration/remote/test/aws_ec2_ami_multiple/aws_ami-ami-025962fd8b456731f.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ami_multiple/aws_ami-ami-025962fd8b456731f.res.golden.json rename to enumeration/remote/test/aws_ec2_ami_multiple/aws_ami-ami-025962fd8b456731f.res.golden.json diff --git a/pkg/remote/test/aws_ec2_ami_multiple/aws_ami-ami-03a578b46f4c3081b.res.golden.json b/enumeration/remote/test/aws_ec2_ami_multiple/aws_ami-ami-03a578b46f4c3081b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ami_multiple/aws_ami-ami-03a578b46f4c3081b.res.golden.json rename to enumeration/remote/test/aws_ec2_ami_multiple/aws_ami-ami-03a578b46f4c3081b.res.golden.json diff --git a/pkg/remote/test/aws_ec2_ami_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_ami_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ami_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_ami_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_default_network_acl/aws_default_network_acl-acl-e88ee595.res.golden.json b/enumeration/remote/test/aws_ec2_default_network_acl/aws_default_network_acl-acl-e88ee595.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_network_acl/aws_default_network_acl-acl-e88ee595.res.golden.json rename to enumeration/remote/test/aws_ec2_default_network_acl/aws_default_network_acl-acl-e88ee595.res.golden.json diff --git a/enumeration/remote/test/aws_ec2_default_network_acl/results.golden.json b/enumeration/remote/test/aws_ec2_default_network_acl/results.golden.json new file mode 100755 index 00000000..1df22de5 --- /dev/null +++ b/enumeration/remote/test/aws_ec2_default_network_acl/results.golden.json @@ -0,0 +1,44 @@ +[ + { + "arn": "arn:aws:ec2:us-east-1:929327065333:network-acl/acl-e88ee595", + "default_network_acl_id": null, + "egress": [ + { + "action": "allow", + "cidr_block": "0.0.0.0/0", + "from_port": 0, + "icmp_code": 0, + "icmp_type": 0, + "ipv6_cidr_block": "", + "protocol": "17", + "rule_no": 100, + "to_port": 0 + } + ], + "id": "acl-e88ee595", + "ingress": [ + { + "action": "allow", + "cidr_block": "172.31.0.0/16", + "from_port": 0, + "icmp_code": 0, + "icmp_type": 0, + "ipv6_cidr_block": "", + "protocol": "6", + "rule_no": 100, + "to_port": 0 + } + ], + "owner_id": "929327065333", + "subnet_ids": [ + "subnet-0000ae0e", + "subnet-1138032f", + "subnet-44fe0c65", + "subnet-65e16628", + "subnet-ad17e3cb", + "subnet-afa656f0" + ], + "tags": null, + "vpc_id": "vpc-41d1d13b" + } +] \ No newline at end of file diff --git a/pkg/remote/test/aws_ec2_default_route_table_single/aws_default_route_table-rtb-0eabf071c709c0976-vpc-0b4a6b3536da20ecd.res.golden.json b/enumeration/remote/test/aws_ec2_default_route_table_single/aws_default_route_table-rtb-0eabf071c709c0976-vpc-0b4a6b3536da20ecd.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_route_table_single/aws_default_route_table-rtb-0eabf071c709c0976-vpc-0b4a6b3536da20ecd.res.golden.json rename to enumeration/remote/test/aws_ec2_default_route_table_single/aws_default_route_table-rtb-0eabf071c709c0976-vpc-0b4a6b3536da20ecd.res.golden.json diff --git a/pkg/remote/test/aws_ec2_default_route_table_single/results.golden.json b/enumeration/remote/test/aws_ec2_default_route_table_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_route_table_single/results.golden.json rename to enumeration/remote/test/aws_ec2_default_route_table_single/results.golden.json diff --git a/pkg/remote/test/aws_ec2_default_route_table_single/terraform.tf b/enumeration/remote/test/aws_ec2_default_route_table_single/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_default_route_table_single/terraform.tf rename to enumeration/remote/test/aws_ec2_default_route_table_single/terraform.tf diff --git a/pkg/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-44fe0c65.res.golden.json b/enumeration/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-44fe0c65.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-44fe0c65.res.golden.json rename to enumeration/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-44fe0c65.res.golden.json diff --git a/pkg/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-65e16628.res.golden.json b/enumeration/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-65e16628.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-65e16628.res.golden.json rename to enumeration/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-65e16628.res.golden.json diff --git a/pkg/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-afa656f0.res.golden.json b/enumeration/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-afa656f0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-afa656f0.res.golden.json rename to enumeration/remote/test/aws_ec2_default_subnet_multiple/aws_default_subnet-subnet-afa656f0.res.golden.json diff --git a/pkg/remote/test/aws_ec2_default_subnet_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_default_subnet_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_default_subnet_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_default_subnet_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_default_subnet_multiple/terraform.tf b/enumeration/remote/test/aws_ec2_default_subnet_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_default_subnet_multiple/terraform.tf rename to enumeration/remote/test/aws_ec2_default_subnet_multiple/terraform.tf diff --git a/pkg/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-00672558cecd93a61.res.golden.json b/enumeration/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-00672558cecd93a61.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-00672558cecd93a61.res.golden.json rename to enumeration/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-00672558cecd93a61.res.golden.json diff --git a/pkg/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-0c509a2a880d95a39.res.golden.json b/enumeration/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-0c509a2a880d95a39.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-0c509a2a880d95a39.res.golden.json rename to enumeration/remote/test/aws_ec2_ebs_snapshot_multiple/aws_ebs_snapshot-snap-0c509a2a880d95a39.res.golden.json diff --git a/pkg/remote/test/aws_ec2_ebs_snapshot_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_ebs_snapshot_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ebs_snapshot_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_ebs_snapshot_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-01ddc91d3d9d1318b.res.golden.json b/enumeration/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-01ddc91d3d9d1318b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-01ddc91d3d9d1318b.res.golden.json rename to enumeration/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-01ddc91d3d9d1318b.res.golden.json diff --git a/pkg/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-081c7272a57a09db1.res.golden.json b/enumeration/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-081c7272a57a09db1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-081c7272a57a09db1.res.golden.json rename to enumeration/remote/test/aws_ec2_ebs_volume_multiple/aws_ebs_volume-vol-081c7272a57a09db1.res.golden.json diff --git a/pkg/remote/test/aws_ec2_ebs_volume_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_ebs_volume_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_ebs_volume_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_ebs_volume_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_eip_association_single/aws_eip_association-eipassoc-0e9a7356e30f0c3d1.res.golden.json b/enumeration/remote/test/aws_ec2_eip_association_single/aws_eip_association-eipassoc-0e9a7356e30f0c3d1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_eip_association_single/aws_eip_association-eipassoc-0e9a7356e30f0c3d1.res.golden.json rename to enumeration/remote/test/aws_ec2_eip_association_single/aws_eip_association-eipassoc-0e9a7356e30f0c3d1.res.golden.json diff --git a/pkg/remote/test/aws_ec2_eip_association_single/results.golden.json b/enumeration/remote/test/aws_ec2_eip_association_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_eip_association_single/results.golden.json rename to enumeration/remote/test/aws_ec2_eip_association_single/results.golden.json diff --git a/pkg/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-017d5267e4dda73f1.res.golden.json b/enumeration/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-017d5267e4dda73f1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-017d5267e4dda73f1.res.golden.json rename to enumeration/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-017d5267e4dda73f1.res.golden.json diff --git a/pkg/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-0cf714dc097c992cc.res.golden.json b/enumeration/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-0cf714dc097c992cc.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-0cf714dc097c992cc.res.golden.json rename to enumeration/remote/test/aws_ec2_eip_multiple/aws_eip-eipalloc-0cf714dc097c992cc.res.golden.json diff --git a/pkg/remote/test/aws_ec2_eip_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_eip_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_eip_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_eip_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_instance_multiple/aws_instance-i-010376047a71419f1.res.golden.json b/enumeration/remote/test/aws_ec2_instance_multiple/aws_instance-i-010376047a71419f1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_instance_multiple/aws_instance-i-010376047a71419f1.res.golden.json rename to enumeration/remote/test/aws_ec2_instance_multiple/aws_instance-i-010376047a71419f1.res.golden.json diff --git a/pkg/remote/test/aws_ec2_instance_multiple/aws_instance-i-0d3650a23f4e45dc0.res.golden.json b/enumeration/remote/test/aws_ec2_instance_multiple/aws_instance-i-0d3650a23f4e45dc0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_instance_multiple/aws_instance-i-0d3650a23f4e45dc0.res.golden.json rename to enumeration/remote/test/aws_ec2_instance_multiple/aws_instance-i-0d3650a23f4e45dc0.res.golden.json diff --git a/pkg/remote/test/aws_ec2_instance_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_instance_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_instance_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_instance_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_instance_terminated/aws_instance-i-0a3a7ed51ae2b4fa0.res.golden.json b/enumeration/remote/test/aws_ec2_instance_terminated/aws_instance-i-0a3a7ed51ae2b4fa0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_instance_terminated/aws_instance-i-0a3a7ed51ae2b4fa0.res.golden.json rename to enumeration/remote/test/aws_ec2_instance_terminated/aws_instance-i-0a3a7ed51ae2b4fa0.res.golden.json diff --git a/pkg/remote/test/aws_ec2_instance_terminated/aws_instance-i-0e1543baf4f2cd990.res.golden.json b/enumeration/remote/test/aws_ec2_instance_terminated/aws_instance-i-0e1543baf4f2cd990.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_instance_terminated/aws_instance-i-0e1543baf4f2cd990.res.golden.json rename to enumeration/remote/test/aws_ec2_instance_terminated/aws_instance-i-0e1543baf4f2cd990.res.golden.json diff --git a/pkg/remote/test/aws_ec2_instance_terminated/results.golden.json b/enumeration/remote/test/aws_ec2_instance_terminated/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_instance_terminated/results.golden.json rename to enumeration/remote/test/aws_ec2_instance_terminated/results.golden.json diff --git a/pkg/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-0184eb41aadc62d1c.res.golden.json b/enumeration/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-0184eb41aadc62d1c.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-0184eb41aadc62d1c.res.golden.json rename to enumeration/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-0184eb41aadc62d1c.res.golden.json diff --git a/pkg/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-047b487f5c60fca99.res.golden.json b/enumeration/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-047b487f5c60fca99.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-047b487f5c60fca99.res.golden.json rename to enumeration/remote/test/aws_ec2_internet_gateway_multiple/aws_internet_gateway-igw-047b487f5c60fca99.res.golden.json diff --git a/pkg/remote/test/aws_ec2_internet_gateway_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_internet_gateway_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_internet_gateway_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_internet_gateway_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_internet_gateway_multiple/terraform.tf b/enumeration/remote/test/aws_ec2_internet_gateway_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_internet_gateway_multiple/terraform.tf rename to enumeration/remote/test/aws_ec2_internet_gateway_multiple/terraform.tf diff --git a/pkg/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-bar.res.golden.json b/enumeration/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-bar.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-bar.res.golden.json rename to enumeration/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-bar.res.golden.json diff --git a/pkg/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-test.res.golden.json b/enumeration/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-test.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-test.res.golden.json rename to enumeration/remote/test/aws_ec2_key_pair_multiple/aws_key_pair-test.res.golden.json diff --git a/pkg/remote/test/aws_ec2_key_pair_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_key_pair_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_key_pair_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_key_pair_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_nat_gateway_single/aws_nat_gateway-nat-0a5408508b19ef490.res.golden.json b/enumeration/remote/test/aws_ec2_nat_gateway_single/aws_nat_gateway-nat-0a5408508b19ef490.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_nat_gateway_single/aws_nat_gateway-nat-0a5408508b19ef490.res.golden.json rename to enumeration/remote/test/aws_ec2_nat_gateway_single/aws_nat_gateway-nat-0a5408508b19ef490.res.golden.json diff --git a/pkg/remote/test/aws_ec2_nat_gateway_single/results.golden.json b/enumeration/remote/test/aws_ec2_nat_gateway_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_nat_gateway_single/results.golden.json rename to enumeration/remote/test/aws_ec2_nat_gateway_single/results.golden.json diff --git a/pkg/remote/test/aws_ec2_nat_gateway_single/terraform.tf b/enumeration/remote/test/aws_ec2_nat_gateway_single/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_nat_gateway_single/terraform.tf rename to enumeration/remote/test/aws_ec2_nat_gateway_single/terraform.tf diff --git a/pkg/remote/test/aws_ec2_network_acl/aws_network_acl-acl-043880b4682d2366b.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl/aws_network_acl-acl-043880b4682d2366b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl/aws_network_acl-acl-043880b4682d2366b.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl/aws_network_acl-acl-043880b4682d2366b.res.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl/aws_network_acl-acl-07a565dbe518c0713.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl/aws_network_acl-acl-07a565dbe518c0713.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl/aws_network_acl-acl-07a565dbe518c0713.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl/aws_network_acl-acl-07a565dbe518c0713.res.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl/results.golden.json b/enumeration/remote/test/aws_ec2_network_acl/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl/results.golden.json rename to enumeration/remote/test/aws_ec2_network_acl/results.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl/terraform.tf b/enumeration/remote/test/aws_ec2_network_acl/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl/terraform.tf rename to enumeration/remote/test/aws_ec2_network_acl/terraform.tf diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-2289824980-false-acl-0ad6d657494d17ee2-200.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-2289824980-false-acl-0ad6d657494d17ee2-200.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-2289824980-false-acl-0ad6d657494d17ee2-200.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-2289824980-false-acl-0ad6d657494d17ee2-200.res.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-246660311-true-acl-0ad6d657494d17ee2-100.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-246660311-true-acl-0ad6d657494d17ee2-100.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-246660311-true-acl-0ad6d657494d17ee2-100.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-246660311-true-acl-0ad6d657494d17ee2-100.res.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4268384215-true-acl-0de54ef59074b622e-100.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4268384215-true-acl-0de54ef59074b622e-100.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4268384215-true-acl-0de54ef59074b622e-100.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4268384215-true-acl-0de54ef59074b622e-100.res.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4293207588-false-acl-0ad6d657494d17ee2-100.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4293207588-false-acl-0ad6d657494d17ee2-100.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4293207588-false-acl-0ad6d657494d17ee2-100.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-4293207588-false-acl-0ad6d657494d17ee2-100.res.golden.json diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-515082162-false-acl-0de54ef59074b622e-100.res.golden.json b/enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-515082162-false-acl-0de54ef59074b622e-100.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-515082162-false-acl-0de54ef59074b622e-100.res.golden.json rename to enumeration/remote/test/aws_ec2_network_acl_rule/aws_network_acl_rule-nacl-515082162-false-acl-0de54ef59074b622e-100.res.golden.json diff --git a/enumeration/remote/test/aws_ec2_network_acl_rule/results.golden.json b/enumeration/remote/test/aws_ec2_network_acl_rule/results.golden.json new file mode 100755 index 00000000..8c878893 --- /dev/null +++ b/enumeration/remote/test/aws_ec2_network_acl_rule/results.golden.json @@ -0,0 +1,72 @@ +[ + { + "cidr_block": "172.31.0.0/16", + "egress": true, + "from_port": 80, + "icmp_code": null, + "icmp_type": null, + "id": "nacl-246660311", + "ipv6_cidr_block": "", + "network_acl_id": "acl-0ad6d657494d17ee2", + "protocol": "udp", + "rule_action": "allow", + "rule_number": 100, + "to_port": 80 + }, + { + "cidr_block": "", + "egress": false, + "from_port": 80, + "icmp_code": null, + "icmp_type": null, + "id": "nacl-2289824980", + "ipv6_cidr_block": "::/0", + "network_acl_id": "acl-0ad6d657494d17ee2", + "protocol": "tcp", + "rule_action": "allow", + "rule_number": 200, + "to_port": 80 + }, + { + "cidr_block": "172.31.0.0/16", + "egress": false, + "from_port": 80, + "icmp_code": null, + "icmp_type": null, + "id": "nacl-515082162", + "ipv6_cidr_block": "", + "network_acl_id": "acl-0de54ef59074b622e", + "protocol": "udp", + "rule_action": "allow", + "rule_number": 100, + "to_port": 80 + }, + { + "cidr_block": "172.31.0.0/16", + "egress": true, + "from_port": 80, + "icmp_code": null, + "icmp_type": null, + "id": "nacl-4268384215", + "ipv6_cidr_block": "", + "network_acl_id": "acl-0de54ef59074b622e", + "protocol": "udp", + "rule_action": "allow", + "rule_number": 100, + "to_port": 80 + }, + { + "cidr_block": "172.31.0.0/16", + "egress": false, + "from_port": 80, + "icmp_code": null, + "icmp_type": null, + "id": "nacl-4293207588", + "ipv6_cidr_block": "", + "network_acl_id": "acl-0ad6d657494d17ee2", + "protocol": "tcp", + "rule_action": "allow", + "rule_number": 100, + "to_port": 80 + } +] \ No newline at end of file diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/terraform.tf b/enumeration/remote/test/aws_ec2_network_acl_rule/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_network_acl_rule/terraform.tf rename to enumeration/remote/test/aws_ec2_network_acl_rule/terraform.tf diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-179966490-10.0.0.0_16-.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-179966490-10.0.0.0_16-.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-179966490-10.0.0.0_16-.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-179966490-10.0.0.0_16-.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc1080289494-0.0.0.0_0-rtb-0169b0937fd963ddc.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc1080289494-0.0.0.0_0-rtb-0169b0937fd963ddc.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc1080289494-0.0.0.0_0-rtb-0169b0937fd963ddc.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc1080289494-0.0.0.0_0-rtb-0169b0937fd963ddc.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc179966490-10.0.0.0_16-rtb-0169b0937fd963ddc.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc179966490-10.0.0.0_16-rtb-0169b0937fd963ddc.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc179966490-10.0.0.0_16-rtb-0169b0937fd963ddc.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc179966490-10.0.0.0_16-rtb-0169b0937fd963ddc.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc2750132062-___0-rtb-0169b0937fd963ddc.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc2750132062-___0-rtb-0169b0937fd963ddc.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc2750132062-___0-rtb-0169b0937fd963ddc.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-0169b0937fd963ddc2750132062-___0-rtb-0169b0937fd963ddc.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c51325105504-10.1.2.0_24-rtb-02780c485f0be93c5.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c51325105504-10.1.2.0_24-rtb-02780c485f0be93c5.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c51325105504-10.1.2.0_24-rtb-02780c485f0be93c5.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c51325105504-10.1.2.0_24-rtb-02780c485f0be93c5.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c5179966490-10.0.0.0_16-rtb-02780c485f0be93c5.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c5179966490-10.0.0.0_16-rtb-02780c485f0be93c5.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c5179966490-10.0.0.0_16-rtb-02780c485f0be93c5.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c5179966490-10.0.0.0_16-rtb-02780c485f0be93c5.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c53362780110-10.1.1.0_24-rtb-02780c485f0be93c5.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c53362780110-10.1.1.0_24-rtb-02780c485f0be93c5.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c53362780110-10.1.1.0_24-rtb-02780c485f0be93c5.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-02780c485f0be93c53362780110-10.1.1.0_24-rtb-02780c485f0be93c5.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3179966490-10.0.0.0_16-rtb-096bdfb69309c54c3.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3179966490-10.0.0.0_16-rtb-096bdfb69309c54c3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3179966490-10.0.0.0_16-rtb-096bdfb69309c54c3.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3179966490-10.0.0.0_16-rtb-096bdfb69309c54c3.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3243279527-1.1.1.1_32-rtb-096bdfb69309c54c3.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3243279527-1.1.1.1_32-rtb-096bdfb69309c54c3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3243279527-1.1.1.1_32-rtb-096bdfb69309c54c3.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c3243279527-1.1.1.1_32-rtb-096bdfb69309c54c3.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c32750132062-___0-rtb-096bdfb69309c54c3.res.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c32750132062-___0-rtb-096bdfb69309c54c3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c32750132062-___0-rtb-096bdfb69309c54c3.res.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/aws_route-r-rtb-096bdfb69309c54c32750132062-___0-rtb-096bdfb69309c54c3.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_route_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_route_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_route_multiple/terraform.tf b/enumeration/remote/test/aws_ec2_route_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_route_multiple/terraform.tf rename to enumeration/remote/test/aws_ec2_route_multiple/terraform.tf diff --git a/pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-01957791b2cfe6ea4-rtb-05aa6c5673311a17b.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-01957791b2cfe6ea4-rtb-05aa6c5673311a17b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-01957791b2cfe6ea4-rtb-05aa6c5673311a17b.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-01957791b2cfe6ea4-rtb-05aa6c5673311a17b.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0809598f92dbec03b-rtb-05aa6c5673311a17b.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0809598f92dbec03b-rtb-05aa6c5673311a17b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0809598f92dbec03b-rtb-05aa6c5673311a17b.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0809598f92dbec03b-rtb-05aa6c5673311a17b.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0a79ccacfceb4944b-rtb-09df7cc9d16de9f8f.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0a79ccacfceb4944b-rtb-09df7cc9d16de9f8f.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0a79ccacfceb4944b-rtb-09df7cc9d16de9f8f.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0a79ccacfceb4944b-rtb-09df7cc9d16de9f8f.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0b4f97ea57490e213-rtb-05aa6c5673311a17b.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0b4f97ea57490e213-rtb-05aa6c5673311a17b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0b4f97ea57490e213-rtb-05aa6c5673311a17b.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_association_multiple/aws_route_table_association-rtbassoc-0b4f97ea57490e213-rtb-05aa6c5673311a17b.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_association_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_route_table_association_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_association_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_route_table_association_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0002ac731f6fdea55.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0002ac731f6fdea55.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0002ac731f6fdea55.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0002ac731f6fdea55.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-08b7b71af15e183ce.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-08b7b71af15e183ce.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-08b7b71af15e183ce.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-08b7b71af15e183ce.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0c55d55593f33fbac.res.golden.json b/enumeration/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0c55d55593f33fbac.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0c55d55593f33fbac.res.golden.json rename to enumeration/remote/test/aws_ec2_route_table_multiple/aws_route_table-rtb-0c55d55593f33fbac.res.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_route_table_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_route_table_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_route_table_multiple/terraform.tf b/enumeration/remote/test/aws_ec2_route_table_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_route_table_multiple/terraform.tf rename to enumeration/remote/test/aws_ec2_route_table_multiple/terraform.tf diff --git a/pkg/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-05810d3f933925f6d.res.golden.json b/enumeration/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-05810d3f933925f6d.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-05810d3f933925f6d.res.golden.json rename to enumeration/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-05810d3f933925f6d.res.golden.json diff --git a/pkg/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0b13f1e0eacf67424.res.golden.json b/enumeration/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0b13f1e0eacf67424.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0b13f1e0eacf67424.res.golden.json rename to enumeration/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0b13f1e0eacf67424.res.golden.json diff --git a/pkg/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0c9b78001fe186e22.res.golden.json b/enumeration/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0c9b78001fe186e22.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0c9b78001fe186e22.res.golden.json rename to enumeration/remote/test/aws_ec2_subnet_multiple/aws_subnet-subnet-0c9b78001fe186e22.res.golden.json diff --git a/pkg/remote/test/aws_ec2_subnet_multiple/results.golden.json b/enumeration/remote/test/aws_ec2_subnet_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ec2_subnet_multiple/results.golden.json rename to enumeration/remote/test/aws_ec2_subnet_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ec2_subnet_multiple/terraform.tf b/enumeration/remote/test/aws_ec2_subnet_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ec2_subnet_multiple/terraform.tf rename to enumeration/remote/test/aws_ec2_subnet_multiple/terraform.tf diff --git a/pkg/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-bar.res.golden.json b/enumeration/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-bar.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-bar.res.golden.json rename to enumeration/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-bar.res.golden.json diff --git a/pkg/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-test_ecr.res.golden.json b/enumeration/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-test_ecr.res.golden.json similarity index 100% rename from pkg/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-test_ecr.res.golden.json rename to enumeration/remote/test/aws_ecr_repository_multiple/aws_ecr_repository-test_ecr.res.golden.json diff --git a/pkg/remote/test/aws_ecr_repository_multiple/results.golden.json b/enumeration/remote/test/aws_ecr_repository_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_ecr_repository_multiple/results.golden.json rename to enumeration/remote/test/aws_ecr_repository_multiple/results.golden.json diff --git a/pkg/remote/test/aws_ecr_repository_multiple/terraform.tf b/enumeration/remote/test/aws_ecr_repository_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_ecr_repository_multiple/terraform.tf rename to enumeration/remote/test/aws_ecr_repository_multiple/terraform.tf diff --git a/pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD223VWU32A-test-driftctl.res.golden.json b/enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD223VWU32A-test-driftctl.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD223VWU32A-test-driftctl.res.golden.json rename to enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD223VWU32A-test-driftctl.res.golden.json diff --git a/pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD26EJME25D-test-driftctl2.res.golden.json b/enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD26EJME25D-test-driftctl2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD26EJME25D-test-driftctl2.res.golden.json rename to enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD26EJME25D-test-driftctl2.res.golden.json diff --git a/pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2QYI36UZP-test-driftctl.res.golden.json b/enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2QYI36UZP-test-driftctl.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2QYI36UZP-test-driftctl.res.golden.json rename to enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2QYI36UZP-test-driftctl.res.golden.json diff --git a/pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2SWDFVVMG-test-driftctl2.res.golden.json b/enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2SWDFVVMG-test-driftctl2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2SWDFVVMG-test-driftctl2.res.golden.json rename to enumeration/remote/test/aws_iam_access_key_multiple/aws_iam_access_key-AKIA5QYBVVD2SWDFVVMG-test-driftctl2.res.golden.json diff --git a/pkg/remote/test/aws_iam_access_key_multiple/results.golden.json b/enumeration/remote/test/aws_iam_access_key_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_access_key_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_access_key_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_access_key_multiple/terraform.tf b/enumeration/remote/test/aws_iam_access_key_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_iam_access_key_multiple/terraform.tf rename to enumeration/remote/test/aws_iam_access_key_multiple/terraform.tf diff --git a/pkg/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-0.res.golden.json b/enumeration/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-0.res.golden.json rename to enumeration/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-0.res.golden.json diff --git a/pkg/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-1.res.golden.json b/enumeration/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-1.res.golden.json rename to enumeration/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-1.res.golden.json diff --git a/pkg/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-2.res.golden.json b/enumeration/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-2.res.golden.json rename to enumeration/remote/test/aws_iam_policy_multiple/aws_iam_policy-arn_aws_iam__929327065333_policy_policy-2.res.golden.json diff --git a/pkg/remote/test/aws_iam_policy_multiple/results.golden.json b/enumeration/remote/test/aws_iam_policy_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_policy_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_policy_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_0.res.golden.json b/enumeration/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_0.res.golden.json rename to enumeration/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_0.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_1.res.golden.json b/enumeration/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_1.res.golden.json rename to enumeration/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_1.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_2.res.golden.json b/enumeration/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_2.res.golden.json rename to enumeration/remote/test/aws_iam_role_multiple/aws_iam_role-test_role_2.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_multiple/results.golden.json b/enumeration/remote/test/aws_iam_role_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_role_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_role_multiple/terraform.tf b/enumeration/remote/test/aws_iam_role_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_iam_role_multiple/terraform.tf rename to enumeration/remote/test/aws_iam_role_multiple/terraform.tf diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role-arn_aws_iam__929327065333_policy_test-policy-test-role.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role-arn_aws_iam__929327065333_policy_test-policy-test-role.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role-arn_aws_iam__929327065333_policy_test-policy-test-role.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role-arn_aws_iam__929327065333_policy_test-policy-test-role.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role2-arn_aws_iam__929327065333_policy_test-policy-test-role2.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role2-arn_aws_iam__929327065333_policy_test-policy-test-role2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role2-arn_aws_iam__929327065333_policy_test-policy-test-role2.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy-test-role2-arn_aws_iam__929327065333_policy_test-policy-test-role2.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role-arn_aws_iam__929327065333_policy_test-policy2-test-role.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role-arn_aws_iam__929327065333_policy_test-policy2-test-role.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role-arn_aws_iam__929327065333_policy_test-policy2-test-role.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role-arn_aws_iam__929327065333_policy_test-policy2-test-role.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role2-arn_aws_iam__929327065333_policy_test-policy2-test-role2.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role2-arn_aws_iam__929327065333_policy_test-policy2-test-role2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role2-arn_aws_iam__929327065333_policy_test-policy2-test-role2.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy2-test-role2-arn_aws_iam__929327065333_policy_test-policy2-test-role2.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role-arn_aws_iam__929327065333_policy_test-policy3-test-role.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role-arn_aws_iam__929327065333_policy_test-policy3-test-role.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role-arn_aws_iam__929327065333_policy_test-policy3-test-role.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role-arn_aws_iam__929327065333_policy_test-policy3-test-role.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role2-arn_aws_iam__929327065333_policy_test-policy3-test-role2.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role2-arn_aws_iam__929327065333_policy_test-policy3-test-role2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role2-arn_aws_iam__929327065333_policy_test-policy3-test-role2.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/aws_iam_role_policy_attachment-test-policy3-test-role2-arn_aws_iam__929327065333_policy_test-policy3-test-role2.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/main.tf b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/main.tf similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/main.tf rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/main.tf diff --git a/pkg/remote/test/aws_iam_role_policy_attachment_multiple/results.golden.json b/enumeration/remote/test/aws_iam_role_policy_attachment_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_attachment_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_role_policy_attachment_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_empty/schema.golden.json b/enumeration/remote/test/aws_iam_role_policy_empty/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_empty/schema.golden.json rename to enumeration/remote/test/aws_iam_role_policy_empty/schema.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-0.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-0.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-0.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-1.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-1.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-1.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-2.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-2.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_0_policy-role0-2.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-0.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-0.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-0.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-1.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-1.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-1.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-2.res.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-2.res.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/aws_iam_role_policy-test_role_1_policy-role1-2.res.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/results.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/schema.golden.json b/enumeration/remote/test/aws_iam_role_policy_multiple/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/schema.golden.json rename to enumeration/remote/test/aws_iam_role_policy_multiple/schema.golden.json diff --git a/pkg/remote/test/aws_iam_role_policy_multiple/terraform.tf b/enumeration/remote/test/aws_iam_role_policy_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_iam_role_policy_multiple/terraform.tf rename to enumeration/remote/test/aws_iam_role_policy_multiple/terraform.tf diff --git a/pkg/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-0.res.golden.json b/enumeration/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-0.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-0.res.golden.json rename to enumeration/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-0.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-1.res.golden.json b/enumeration/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-1.res.golden.json rename to enumeration/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-1.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-2.res.golden.json b/enumeration/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-2.res.golden.json rename to enumeration/remote/test/aws_iam_user_multiple/aws_iam_user-test-driftctl-2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_multiple/results.golden.json b/enumeration/remote/test/aws_iam_user_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_user_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_user_multiple/terraform.tf b/enumeration/remote/test/aws_iam_user_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_iam_user_multiple/terraform.tf rename to enumeration/remote/test/aws_iam_user_multiple/terraform.tf diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer-arn_aws_iam__726421854799_policy_test-loadbalancer.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer-arn_aws_iam__726421854799_policy_test-loadbalancer.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer-arn_aws_iam__726421854799_policy_test-loadbalancer.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer-arn_aws_iam__726421854799_policy_test-loadbalancer.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer2-arn_aws_iam__726421854799_policy_test-loadbalancer2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer2-arn_aws_iam__726421854799_policy_test-loadbalancer2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer2-arn_aws_iam__726421854799_policy_test-loadbalancer2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer2-arn_aws_iam__726421854799_policy_test-loadbalancer2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer3-arn_aws_iam__726421854799_policy_test-loadbalancer3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer3-arn_aws_iam__726421854799_policy_test-loadbalancer3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer3-arn_aws_iam__726421854799_policy_test-loadbalancer3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test-loadbalancer3-arn_aws_iam__726421854799_policy_test-loadbalancer3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer-arn_aws_iam__726421854799_policy_test2-loadbalancer.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer-arn_aws_iam__726421854799_policy_test2-loadbalancer.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer-arn_aws_iam__726421854799_policy_test2-loadbalancer.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer-arn_aws_iam__726421854799_policy_test2-loadbalancer.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer2-arn_aws_iam__726421854799_policy_test2-loadbalancer2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer2-arn_aws_iam__726421854799_policy_test2-loadbalancer2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer2-arn_aws_iam__726421854799_policy_test2-loadbalancer2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer2-arn_aws_iam__726421854799_policy_test2-loadbalancer2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer3-arn_aws_iam__726421854799_policy_test2-loadbalancer3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer3-arn_aws_iam__726421854799_policy_test2-loadbalancer3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer3-arn_aws_iam__726421854799_policy_test2-loadbalancer3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test2-loadbalancer3-arn_aws_iam__726421854799_policy_test2-loadbalancer3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer-arn_aws_iam__726421854799_policy_test3-loadbalancer.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer-arn_aws_iam__726421854799_policy_test3-loadbalancer.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer-arn_aws_iam__726421854799_policy_test3-loadbalancer.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer-arn_aws_iam__726421854799_policy_test3-loadbalancer.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer2-arn_aws_iam__726421854799_policy_test3-loadbalancer2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer2-arn_aws_iam__726421854799_policy_test3-loadbalancer2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer2-arn_aws_iam__726421854799_policy_test3-loadbalancer2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer2-arn_aws_iam__726421854799_policy_test3-loadbalancer2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer3-arn_aws_iam__726421854799_policy_test3-loadbalancer3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer3-arn_aws_iam__726421854799_policy_test3-loadbalancer3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer3-arn_aws_iam__726421854799_policy_test3-loadbalancer3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test3-loadbalancer3-arn_aws_iam__726421854799_policy_test3-loadbalancer3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer-arn_aws_iam__726421854799_policy_test4-loadbalancer.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer-arn_aws_iam__726421854799_policy_test4-loadbalancer.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer-arn_aws_iam__726421854799_policy_test4-loadbalancer.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer-arn_aws_iam__726421854799_policy_test4-loadbalancer.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer2-arn_aws_iam__726421854799_policy_test4-loadbalancer2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer2-arn_aws_iam__726421854799_policy_test4-loadbalancer2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer2-arn_aws_iam__726421854799_policy_test4-loadbalancer2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer2-arn_aws_iam__726421854799_policy_test4-loadbalancer2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer3-arn_aws_iam__726421854799_policy_test4-loadbalancer3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer3-arn_aws_iam__726421854799_policy_test4-loadbalancer3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer3-arn_aws_iam__726421854799_policy_test4-loadbalancer3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/aws_iam_user_policy_attachment-test4-loadbalancer3-arn_aws_iam__726421854799_policy_test4-loadbalancer3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/main.tf b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/main.tf similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/main.tf rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/main.tf diff --git a/pkg/remote/test/aws_iam_user_policy_attachment_multiple/results.golden.json b/enumeration/remote/test/aws_iam_user_policy_attachment_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_attachment_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_user_policy_attachment_multiple/results.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test22.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test22.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test22.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test22.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test23.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test23.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test23.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test23.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test24.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test24.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test24.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test24.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test4.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test4.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test4.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer2_test4.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test22.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test22.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test22.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test22.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test23.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test23.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test23.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test23.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test24.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test24.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test24.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test24.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test32.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test32.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test32.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test32.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test33.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test33.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test33.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test33.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test34.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test34.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test34.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test34.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test4.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test4.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test4.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer3_test4.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test2.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test2.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test2.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test3.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test3.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test3.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test4.res.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test4.res.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test4.res.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/aws_iam_user_policy-loadbalancer_test4.res.golden.json diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/main.tf b/enumeration/remote/test/aws_iam_user_policy_multiple/main.tf similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/main.tf rename to enumeration/remote/test/aws_iam_user_policy_multiple/main.tf diff --git a/pkg/remote/test/aws_iam_user_policy_multiple/results.golden.json b/enumeration/remote/test/aws_iam_user_policy_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_iam_user_policy_multiple/results.golden.json rename to enumeration/remote/test/aws_iam_user_policy_multiple/results.golden.json diff --git a/pkg/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_bar.res.golden.json b/enumeration/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_bar.res.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_bar.res.golden.json rename to enumeration/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_bar.res.golden.json diff --git a/pkg/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_baz20210225124429210500000001.res.golden.json b/enumeration/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_baz20210225124429210500000001.res.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_baz20210225124429210500000001.res.golden.json rename to enumeration/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_baz20210225124429210500000001.res.golden.json diff --git a/pkg/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_foo.res.golden.json b/enumeration/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_foo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_foo.res.golden.json rename to enumeration/remote/test/aws_kms_alias_multiple/aws_kms_alias-alias_foo.res.golden.json diff --git a/pkg/remote/test/aws_kms_alias_multiple/results.golden.json b/enumeration/remote/test/aws_kms_alias_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_alias_multiple/results.golden.json rename to enumeration/remote/test/aws_kms_alias_multiple/results.golden.json diff --git a/pkg/remote/test/aws_kms_alias_multiple/terraform.tf b/enumeration/remote/test/aws_kms_alias_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_kms_alias_multiple/terraform.tf rename to enumeration/remote/test/aws_kms_alias_multiple/terraform.tf diff --git a/pkg/remote/test/aws_kms_key_multiple/aws_kms_key-5d765f32-bfdc-4610-b6ab-f82db5d0601b.res.golden.json b/enumeration/remote/test/aws_kms_key_multiple/aws_kms_key-5d765f32-bfdc-4610-b6ab-f82db5d0601b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_key_multiple/aws_kms_key-5d765f32-bfdc-4610-b6ab-f82db5d0601b.res.golden.json rename to enumeration/remote/test/aws_kms_key_multiple/aws_kms_key-5d765f32-bfdc-4610-b6ab-f82db5d0601b.res.golden.json diff --git a/pkg/remote/test/aws_kms_key_multiple/aws_kms_key-89d2c023-ea53-40a5-b20a-d84905c622d7.res.golden.json b/enumeration/remote/test/aws_kms_key_multiple/aws_kms_key-89d2c023-ea53-40a5-b20a-d84905c622d7.res.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_key_multiple/aws_kms_key-89d2c023-ea53-40a5-b20a-d84905c622d7.res.golden.json rename to enumeration/remote/test/aws_kms_key_multiple/aws_kms_key-89d2c023-ea53-40a5-b20a-d84905c622d7.res.golden.json diff --git a/pkg/remote/test/aws_kms_key_multiple/aws_kms_key-8ee21d91-c000-428c-8032-235aac55da36.res.golden.json b/enumeration/remote/test/aws_kms_key_multiple/aws_kms_key-8ee21d91-c000-428c-8032-235aac55da36.res.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_key_multiple/aws_kms_key-8ee21d91-c000-428c-8032-235aac55da36.res.golden.json rename to enumeration/remote/test/aws_kms_key_multiple/aws_kms_key-8ee21d91-c000-428c-8032-235aac55da36.res.golden.json diff --git a/pkg/remote/test/aws_kms_key_multiple/results.golden.json b/enumeration/remote/test/aws_kms_key_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_kms_key_multiple/results.golden.json rename to enumeration/remote/test/aws_kms_key_multiple/results.golden.json diff --git a/pkg/remote/test/aws_kms_key_multiple/terraform.tf b/enumeration/remote/test/aws_kms_key_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_kms_key_multiple/terraform.tf rename to enumeration/remote/test/aws_kms_key_multiple/terraform.tf diff --git a/pkg/remote/test/aws_lambda_function_multiple/aws_lambda_function-bar-bar.res.golden.json b/enumeration/remote/test/aws_lambda_function_multiple/aws_lambda_function-bar-bar.res.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_function_multiple/aws_lambda_function-bar-bar.res.golden.json rename to enumeration/remote/test/aws_lambda_function_multiple/aws_lambda_function-bar-bar.res.golden.json diff --git a/pkg/remote/test/aws_lambda_function_multiple/aws_lambda_function-foo-foo.res.golden.json b/enumeration/remote/test/aws_lambda_function_multiple/aws_lambda_function-foo-foo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_function_multiple/aws_lambda_function-foo-foo.res.golden.json rename to enumeration/remote/test/aws_lambda_function_multiple/aws_lambda_function-foo-foo.res.golden.json diff --git a/pkg/remote/test/aws_lambda_function_multiple/results.golden.json b/enumeration/remote/test/aws_lambda_function_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_function_multiple/results.golden.json rename to enumeration/remote/test/aws_lambda_function_multiple/results.golden.json diff --git a/pkg/remote/test/aws_lambda_function_signed/aws_lambda_function-foo-foo.res.golden.json b/enumeration/remote/test/aws_lambda_function_signed/aws_lambda_function-foo-foo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_function_signed/aws_lambda_function-foo-foo.res.golden.json rename to enumeration/remote/test/aws_lambda_function_signed/aws_lambda_function-foo-foo.res.golden.json diff --git a/pkg/remote/test/aws_lambda_function_signed/lambda.zip b/enumeration/remote/test/aws_lambda_function_signed/lambda.zip similarity index 100% rename from pkg/remote/test/aws_lambda_function_signed/lambda.zip rename to enumeration/remote/test/aws_lambda_function_signed/lambda.zip diff --git a/pkg/remote/test/aws_lambda_function_signed/main.tf b/enumeration/remote/test/aws_lambda_function_signed/main.tf similarity index 100% rename from pkg/remote/test/aws_lambda_function_signed/main.tf rename to enumeration/remote/test/aws_lambda_function_signed/main.tf diff --git a/pkg/remote/test/aws_lambda_function_signed/results.golden.json b/enumeration/remote/test/aws_lambda_function_signed/results.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_function_signed/results.golden.json rename to enumeration/remote/test/aws_lambda_function_signed/results.golden.json diff --git a/pkg/remote/test/aws_lambda_source_mapping_dynamo_multiple/aws_lambda_event_source_mapping-1aa9c4a0-060b-41c1-a9ae-dc304ebcdb00.res.golden.json b/enumeration/remote/test/aws_lambda_source_mapping_dynamo_multiple/aws_lambda_event_source_mapping-1aa9c4a0-060b-41c1-a9ae-dc304ebcdb00.res.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_source_mapping_dynamo_multiple/aws_lambda_event_source_mapping-1aa9c4a0-060b-41c1-a9ae-dc304ebcdb00.res.golden.json rename to enumeration/remote/test/aws_lambda_source_mapping_dynamo_multiple/aws_lambda_event_source_mapping-1aa9c4a0-060b-41c1-a9ae-dc304ebcdb00.res.golden.json diff --git a/pkg/remote/test/aws_lambda_source_mapping_dynamo_multiple/results.golden.json b/enumeration/remote/test/aws_lambda_source_mapping_dynamo_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_source_mapping_dynamo_multiple/results.golden.json rename to enumeration/remote/test/aws_lambda_source_mapping_dynamo_multiple/results.golden.json diff --git a/pkg/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-13ff66f8-37eb-4ad6-a0a8-594fea72df4f.res.golden.json b/enumeration/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-13ff66f8-37eb-4ad6-a0a8-594fea72df4f.res.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-13ff66f8-37eb-4ad6-a0a8-594fea72df4f.res.golden.json rename to enumeration/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-13ff66f8-37eb-4ad6-a0a8-594fea72df4f.res.golden.json diff --git a/pkg/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-4ad7e2b3-79e9-4713-9d9d-5af2c01d9058.res.golden.json b/enumeration/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-4ad7e2b3-79e9-4713-9d9d-5af2c01d9058.res.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-4ad7e2b3-79e9-4713-9d9d-5af2c01d9058.res.golden.json rename to enumeration/remote/test/aws_lambda_source_mapping_sqs_multiple/aws_lambda_event_source_mapping-4ad7e2b3-79e9-4713-9d9d-5af2c01d9058.res.golden.json diff --git a/pkg/remote/test/aws_lambda_source_mapping_sqs_multiple/results.golden.json b/enumeration/remote/test/aws_lambda_source_mapping_sqs_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_lambda_source_mapping_sqs_multiple/results.golden.json rename to enumeration/remote/test/aws_lambda_source_mapping_sqs_multiple/results.golden.json diff --git a/pkg/remote/test/aws_launch_template_multiple/aws_launch_template-lt-00b2d18c6cee7fe23.res.golden.json b/enumeration/remote/test/aws_launch_template_multiple/aws_launch_template-lt-00b2d18c6cee7fe23.res.golden.json similarity index 100% rename from pkg/remote/test/aws_launch_template_multiple/aws_launch_template-lt-00b2d18c6cee7fe23.res.golden.json rename to enumeration/remote/test/aws_launch_template_multiple/aws_launch_template-lt-00b2d18c6cee7fe23.res.golden.json diff --git a/pkg/remote/test/aws_launch_template_multiple/aws_launch_template-lt-0ed993d09ce6afc67.res.golden.json b/enumeration/remote/test/aws_launch_template_multiple/aws_launch_template-lt-0ed993d09ce6afc67.res.golden.json similarity index 100% rename from pkg/remote/test/aws_launch_template_multiple/aws_launch_template-lt-0ed993d09ce6afc67.res.golden.json rename to enumeration/remote/test/aws_launch_template_multiple/aws_launch_template-lt-0ed993d09ce6afc67.res.golden.json diff --git a/pkg/remote/test/aws_launch_template_multiple/results.golden.json b/enumeration/remote/test/aws_launch_template_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_launch_template_multiple/results.golden.json rename to enumeration/remote/test/aws_launch_template_multiple/results.golden.json diff --git a/pkg/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-2-aurora-cluster-demo-2-.res.golden.json b/enumeration/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-2-aurora-cluster-demo-2-.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-2-aurora-cluster-demo-2-.res.golden.json rename to enumeration/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-2-aurora-cluster-demo-2-.res.golden.json diff --git a/pkg/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-aurora-cluster-demo-mydb.res.golden.json b/enumeration/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-aurora-cluster-demo-mydb.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-aurora-cluster-demo-mydb.res.golden.json rename to enumeration/remote/test/aws_rds_clusters_results/aws_rds_cluster-aurora-cluster-demo-aurora-cluster-demo-mydb.res.golden.json diff --git a/enumeration/remote/test/aws_rds_clusters_results/results.golden.json b/enumeration/remote/test/aws_rds_clusters_results/results.golden.json new file mode 100755 index 00000000..a17efc9d --- /dev/null +++ b/enumeration/remote/test/aws_rds_clusters_results/results.golden.json @@ -0,0 +1,112 @@ +[ + { + "allow_major_version_upgrade": null, + "apply_immediately": null, + "arn": "arn:aws:rds:us-east-1:533948124879:cluster:aurora-cluster-demo", + "availability_zones": [ + "us-east-1a", + "us-east-1b", + "us-east-1d" + ], + "backtrack_window": 0, + "backup_retention_period": 5, + "cluster_identifier": "aurora-cluster-demo", + "cluster_identifier_prefix": null, + "cluster_members": [ + "aurora-cluster-demo-0" + ], + "cluster_resource_id": "cluster-TISYDSSX4J5R6ZGUTV6LLJW73E", + "copy_tags_to_snapshot": false, + "database_name": "mydb", + "db_cluster_parameter_group_name": "default.aurora-postgresql11", + "db_subnet_group_name": "default", + "deletion_protection": false, + "enable_http_endpoint": false, + "enabled_cloudwatch_logs_exports": null, + "endpoint": "aurora-cluster-demo.cluster-cd539r6quiux.us-east-1.rds.amazonaws.com", + "engine": "aurora-postgresql", + "engine_mode": "provisioned", + "engine_version": "11.9", + "final_snapshot_identifier": null, + "global_cluster_identifier": "", + "hosted_zone_id": "Z2R2ITUGPM61AM", + "iam_database_authentication_enabled": false, + "iam_roles": null, + "id": "aurora-cluster-demo", + "kms_key_id": "", + "master_password": null, + "master_username": "foo", + "port": 5432, + "preferred_backup_window": "07:00-09:00", + "preferred_maintenance_window": "fri:03:03-fri:03:33", + "reader_endpoint": "aurora-cluster-demo.cluster-ro-cd539r6quiux.us-east-1.rds.amazonaws.com", + "replication_source_identifier": "", + "restore_to_point_in_time": null, + "s3_import": null, + "scaling_configuration": null, + "skip_final_snapshot": null, + "snapshot_identifier": null, + "source_region": null, + "storage_encrypted": false, + "tags": null, + "timeouts": {}, + "vpc_security_group_ids": [ + "sg-49e38646" + ] + }, + { + "allow_major_version_upgrade": null, + "apply_immediately": null, + "arn": "arn:aws:rds:us-east-1:533948124879:cluster:aurora-cluster-demo", + "availability_zones": [ + "us-east-1a", + "us-east-1b", + "us-east-1d" + ], + "backtrack_window": 0, + "backup_retention_period": 5, + "cluster_identifier": "aurora-cluster-demo-2", + "cluster_identifier_prefix": null, + "cluster_members": [ + "aurora-cluster-demo-0" + ], + "cluster_resource_id": "cluster-TISYDSSX4J5R6ZGUTV6LLJW73E", + "copy_tags_to_snapshot": false, + "database_name": "", + "db_cluster_parameter_group_name": "default.aurora-postgresql11", + "db_subnet_group_name": "default", + "deletion_protection": false, + "enable_http_endpoint": false, + "enabled_cloudwatch_logs_exports": null, + "endpoint": "aurora-cluster-demo.cluster-cd539r6quiux.us-east-1.rds.amazonaws.com", + "engine": "aurora-postgresql", + "engine_mode": "provisioned", + "engine_version": "11.9", + "final_snapshot_identifier": null, + "global_cluster_identifier": "", + "hosted_zone_id": "Z2R2ITUGPM61AM", + "iam_database_authentication_enabled": false, + "iam_roles": null, + "id": "aurora-cluster-demo-2", + "kms_key_id": "", + "master_password": null, + "master_username": "foo", + "port": 5432, + "preferred_backup_window": "07:00-09:00", + "preferred_maintenance_window": "fri:03:03-fri:03:33", + "reader_endpoint": "aurora-cluster-demo.cluster-ro-cd539r6quiux.us-east-1.rds.amazonaws.com", + "replication_source_identifier": "", + "restore_to_point_in_time": null, + "s3_import": null, + "scaling_configuration": null, + "skip_final_snapshot": null, + "snapshot_identifier": null, + "source_region": null, + "storage_encrypted": false, + "tags": null, + "timeouts": {}, + "vpc_security_group_ids": [ + "sg-49e38646" + ] + } +] \ No newline at end of file diff --git a/pkg/remote/test/aws_rds_db_instance_multiple/aws_db_instance-database-1.res.golden.json b/enumeration/remote/test/aws_rds_db_instance_multiple/aws_db_instance-database-1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_instance_multiple/aws_db_instance-database-1.res.golden.json rename to enumeration/remote/test/aws_rds_db_instance_multiple/aws_db_instance-database-1.res.golden.json diff --git a/pkg/remote/test/aws_rds_db_instance_multiple/aws_db_instance-terraform-20201015115018309600000001.res.golden.json b/enumeration/remote/test/aws_rds_db_instance_multiple/aws_db_instance-terraform-20201015115018309600000001.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_instance_multiple/aws_db_instance-terraform-20201015115018309600000001.res.golden.json rename to enumeration/remote/test/aws_rds_db_instance_multiple/aws_db_instance-terraform-20201015115018309600000001.res.golden.json diff --git a/pkg/remote/test/aws_rds_db_instance_multiple/results.golden.json b/enumeration/remote/test/aws_rds_db_instance_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_instance_multiple/results.golden.json rename to enumeration/remote/test/aws_rds_db_instance_multiple/results.golden.json diff --git a/pkg/remote/test/aws_rds_db_instance_single/aws_db_instance-terraform-20201015115018309600000001.res.golden.json b/enumeration/remote/test/aws_rds_db_instance_single/aws_db_instance-terraform-20201015115018309600000001.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_instance_single/aws_db_instance-terraform-20201015115018309600000001.res.golden.json rename to enumeration/remote/test/aws_rds_db_instance_single/aws_db_instance-terraform-20201015115018309600000001.res.golden.json diff --git a/pkg/remote/test/aws_rds_db_instance_single/results.golden.json b/enumeration/remote/test/aws_rds_db_instance_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_instance_single/results.golden.json rename to enumeration/remote/test/aws_rds_db_instance_single/results.golden.json diff --git a/pkg/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-bar.res.golden.json b/enumeration/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-bar.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-bar.res.golden.json rename to enumeration/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-bar.res.golden.json diff --git a/pkg/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-foo.res.golden.json b/enumeration/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-foo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-foo.res.golden.json rename to enumeration/remote/test/aws_rds_db_subnet_group_multiple/aws_db_subnet_group-foo.res.golden.json diff --git a/pkg/remote/test/aws_rds_db_subnet_group_multiple/results.golden.json b/enumeration/remote/test/aws_rds_db_subnet_group_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_rds_db_subnet_group_multiple/results.golden.json rename to enumeration/remote/test/aws_rds_db_subnet_group_multiple/results.golden.json diff --git a/pkg/remote/test/aws_route53_health_check_empty/schema.golden.json b/enumeration/remote/test/aws_route53_health_check_empty/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_health_check_empty/schema.golden.json rename to enumeration/remote/test/aws_route53_health_check_empty/schema.golden.json diff --git a/pkg/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-7001a9df-ded4-4802-9909-668eb80b972b.res.golden.json b/enumeration/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-7001a9df-ded4-4802-9909-668eb80b972b.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-7001a9df-ded4-4802-9909-668eb80b972b.res.golden.json rename to enumeration/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-7001a9df-ded4-4802-9909-668eb80b972b.res.golden.json diff --git a/pkg/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-84fc318a-2e0d-41d6-b638-280e2f0f4e26.res.golden.json b/enumeration/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-84fc318a-2e0d-41d6-b638-280e2f0f4e26.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-84fc318a-2e0d-41d6-b638-280e2f0f4e26.res.golden.json rename to enumeration/remote/test/aws_route53_health_check_multiple/aws_route53_health_check-84fc318a-2e0d-41d6-b638-280e2f0f4e26.res.golden.json diff --git a/pkg/remote/test/aws_route53_health_check_multiple/results.golden.json b/enumeration/remote/test/aws_route53_health_check_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_health_check_multiple/results.golden.json rename to enumeration/remote/test/aws_route53_health_check_multiple/results.golden.json diff --git a/pkg/remote/test/aws_route53_health_check_multiple/terraform.tf b/enumeration/remote/test/aws_route53_health_check_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_route53_health_check_multiple/terraform.tf rename to enumeration/remote/test/aws_route53_health_check_multiple/terraform.tf diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_A.res.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_TXT.res.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_TXT.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_TXT.res.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM__test2.foo-2.com_TXT.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_A.res.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_TXT.res.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_TXT.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_TXT.res.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test0_TXT.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_A.res.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_TXT.res.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_TXT.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_TXT.res.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/aws_route53_record-Z06486383UC8WYSBZTWFM_test1.foo-2.com_TXT.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/main.tf b/enumeration/remote/test/aws_route53_record_explicit_subdomain/main.tf similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/main.tf rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/main.tf diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/results.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/results.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/results.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/results.golden.json diff --git a/pkg/remote/test/aws_route53_record_explicit_subdomain/schema.golden.json b/enumeration/remote/test/aws_route53_record_explicit_subdomain/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_explicit_subdomain/schema.golden.json rename to enumeration/remote/test/aws_route53_record_explicit_subdomain/schema.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z10347383HV75H96J919W_test2_A.res.golden.json b/enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z10347383HV75H96J919W_test2_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z10347383HV75H96J919W_test2_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z10347383HV75H96J919W_test2_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G__.test4_A.res.golden.json b/enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G__.test4_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G__.test4_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G__.test4_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_foo-0.com_NS.res.golden.json b/enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_foo-0.com_NS.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_foo-0.com_NS.res.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_foo-0.com_NS.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test0_A.res.golden.json b/enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test0_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test0_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test0_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test1_A.res.golden.json b/enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test1_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test1_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test1_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test2_A.res.golden.json b/enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test2_A.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test2_A.res.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/aws_route53_record-Z1035360GLIB82T1EH2G_test2_A.res.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/results.golden.json b/enumeration/remote/test/aws_route53_record_multiples/results.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/results.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/results.golden.json diff --git a/pkg/remote/test/aws_route53_record_multiples/schema.golden.json b/enumeration/remote/test/aws_route53_record_multiples/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_record_multiples/schema.golden.json rename to enumeration/remote/test/aws_route53_record_multiples/schema.golden.json diff --git a/pkg/remote/test/aws_route53_zone_empty/schema.golden.json b/enumeration/remote/test/aws_route53_zone_empty/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_empty/schema.golden.json rename to enumeration/remote/test/aws_route53_zone_empty/schema.golden.json diff --git a/pkg/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01804312AV8PHE3C43AD.res.golden.json b/enumeration/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01804312AV8PHE3C43AD.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01804312AV8PHE3C43AD.res.golden.json rename to enumeration/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01804312AV8PHE3C43AD.res.golden.json diff --git a/pkg/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01809283VH9BBALZHO7B.res.golden.json b/enumeration/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01809283VH9BBALZHO7B.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01809283VH9BBALZHO7B.res.golden.json rename to enumeration/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01809283VH9BBALZHO7B.res.golden.json diff --git a/pkg/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01874941AR1TCGV5K65C.res.golden.json b/enumeration/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01874941AR1TCGV5K65C.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01874941AR1TCGV5K65C.res.golden.json rename to enumeration/remote/test/aws_route53_zone_multiples/aws_route53_zone-Z01874941AR1TCGV5K65C.res.golden.json diff --git a/pkg/remote/test/aws_route53_zone_multiples/results.golden.json b/enumeration/remote/test/aws_route53_zone_multiples/results.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_multiples/results.golden.json rename to enumeration/remote/test/aws_route53_zone_multiples/results.golden.json diff --git a/pkg/remote/test/aws_route53_zone_multiples/schema.golden.json b/enumeration/remote/test/aws_route53_zone_multiples/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_multiples/schema.golden.json rename to enumeration/remote/test/aws_route53_zone_multiples/schema.golden.json diff --git a/pkg/remote/test/aws_route53_zone_multiples/terraform.tf b/enumeration/remote/test/aws_route53_zone_multiples/terraform.tf similarity index 100% rename from pkg/remote/test/aws_route53_zone_multiples/terraform.tf rename to enumeration/remote/test/aws_route53_zone_multiples/terraform.tf diff --git a/pkg/remote/test/aws_route53_zone_single/aws_route53_zone-Z08068311RGDXPHF8KE62.res.golden.json b/enumeration/remote/test/aws_route53_zone_single/aws_route53_zone-Z08068311RGDXPHF8KE62.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_single/aws_route53_zone-Z08068311RGDXPHF8KE62.res.golden.json rename to enumeration/remote/test/aws_route53_zone_single/aws_route53_zone-Z08068311RGDXPHF8KE62.res.golden.json diff --git a/pkg/remote/test/aws_route53_zone_single/aws_route53_zone-Z093553112BLINKY4N57.res.golden.json b/enumeration/remote/test/aws_route53_zone_single/aws_route53_zone-Z093553112BLINKY4N57.res.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_single/aws_route53_zone-Z093553112BLINKY4N57.res.golden.json rename to enumeration/remote/test/aws_route53_zone_single/aws_route53_zone-Z093553112BLINKY4N57.res.golden.json diff --git a/pkg/remote/test/aws_route53_zone_single/results.golden.json b/enumeration/remote/test/aws_route53_zone_single/results.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_single/results.golden.json rename to enumeration/remote/test/aws_route53_zone_single/results.golden.json diff --git a/pkg/remote/test/aws_route53_zone_single/schema.golden.json b/enumeration/remote/test/aws_route53_zone_single/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_single/schema.golden.json rename to enumeration/remote/test/aws_route53_zone_single/schema.golden.json diff --git a/pkg/remote/test/aws_route53_zone_single/terraform.tf b/enumeration/remote/test/aws_route53_zone_single/terraform.tf similarity index 100% rename from pkg/remote/test/aws_route53_zone_single/terraform.tf rename to enumeration/remote/test/aws_route53_zone_single/terraform.tf diff --git a/pkg/remote/test/aws_route53_zone_with_no_record/schema.golden.json b/enumeration/remote/test/aws_route53_zone_with_no_record/schema.golden.json similarity index 100% rename from pkg/remote/test/aws_route53_zone_with_no_record/schema.golden.json rename to enumeration/remote/test/aws_route53_zone_with_no_record/schema.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics2_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics2_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics2_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics2_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_analytics_multiple/aws_s3_bucket_analytics_configuration-bucket-martin-test-drift2_Analytics_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_analytics_multiple/results.golden.json b/enumeration/remote/test/aws_s3_bucket_analytics_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_analytics_multiple/results.golden.json rename to enumeration/remote/test/aws_s3_bucket_analytics_multiple/results.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory2_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_inventories_multiple/aws_s3_bucket_inventory-bucket-martin-test-drift2_Inventory_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_inventories_multiple/results.golden.json b/enumeration/remote/test/aws_s3_bucket_inventories_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_inventories_multiple/results.golden.json rename to enumeration/remote/test/aws_s3_bucket_inventories_multiple/results.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric2_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric2_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric2_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric2_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metric_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics2_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics2_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics2_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics2_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics_Bucket2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics_Bucket2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics_Bucket2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_metrics_multiple/aws_s3_bucket_metric-bucket-martin-test-drift2_Metrics_Bucket2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_metrics_multiple/results.golden.json b/enumeration/remote/test/aws_s3_bucket_metrics_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_metrics_multiple/results.golden.json rename to enumeration/remote/test/aws_s3_bucket_metrics_multiple/results.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_multiple/aws_s3_bucket-bucket-martin-test-drift2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_multiple/aws_s3_bucket-bucket-martin-test-drift2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_multiple/aws_s3_bucket-bucket-martin-test-drift2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_multiple/aws_s3_bucket-bucket-martin-test-drift2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_multiple/main.tf b/enumeration/remote/test/aws_s3_bucket_multiple/main.tf similarity index 100% rename from pkg/remote/test/aws_s3_bucket_multiple/main.tf rename to enumeration/remote/test/aws_s3_bucket_multiple/main.tf diff --git a/pkg/remote/test/aws_s3_bucket_multiple/results.golden.json b/enumeration/remote/test/aws_s3_bucket_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_multiple/results.golden.json rename to enumeration/remote/test/aws_s3_bucket_multiple/results.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_notifications_multiple/aws_s3_bucket_notification-bucket-martin-test-drift2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_notifications_multiple/aws_s3_bucket_notification-bucket-martin-test-drift2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_notifications_multiple/aws_s3_bucket_notification-bucket-martin-test-drift2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_notifications_multiple/aws_s3_bucket_notification-bucket-martin-test-drift2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_notifications_multiple/results.golden.json b/enumeration/remote/test/aws_s3_bucket_notifications_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_notifications_multiple/results.golden.json rename to enumeration/remote/test/aws_s3_bucket_notifications_multiple/results.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_notifications_no_notif/aws_s3_bucket_notification-dritftctl-test-no-notifications.res.golden.json b/enumeration/remote/test/aws_s3_bucket_notifications_no_notif/aws_s3_bucket_notification-dritftctl-test-no-notifications.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_notifications_no_notif/aws_s3_bucket_notification-dritftctl-test-no-notifications.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_notifications_no_notif/aws_s3_bucket_notification-dritftctl-test-no-notifications.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_notifications_no_notif/terraform.tf b/enumeration/remote/test/aws_s3_bucket_notifications_no_notif/terraform.tf similarity index 100% rename from pkg/remote/test/aws_s3_bucket_notifications_no_notif/terraform.tf rename to enumeration/remote/test/aws_s3_bucket_notifications_no_notif/terraform.tf diff --git a/pkg/remote/test/aws_s3_bucket_policies_multiple/aws_s3_bucket_policy-bucket-martin-test-drift2-eu-west-3.res.golden.json b/enumeration/remote/test/aws_s3_bucket_policies_multiple/aws_s3_bucket_policy-bucket-martin-test-drift2-eu-west-3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_policies_multiple/aws_s3_bucket_policy-bucket-martin-test-drift2-eu-west-3.res.golden.json rename to enumeration/remote/test/aws_s3_bucket_policies_multiple/aws_s3_bucket_policy-bucket-martin-test-drift2-eu-west-3.res.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_policies_multiple/results.golden.json b/enumeration/remote/test/aws_s3_bucket_policies_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_s3_bucket_policies_multiple/results.golden.json rename to enumeration/remote/test/aws_s3_bucket_policies_multiple/results.golden.json diff --git a/pkg/remote/test/aws_s3_bucket_policy_no_policy/terraform.tf b/enumeration/remote/test/aws_s3_bucket_policy_no_policy/terraform.tf similarity index 100% rename from pkg/remote/test/aws_s3_bucket_policy_no_policy/terraform.tf rename to enumeration/remote/test/aws_s3_bucket_policy_no_policy/terraform.tf diff --git a/pkg/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic.res.golden.json b/enumeration/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic.res.golden.json rename to enumeration/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2.res.golden.json b/enumeration/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2.res.golden.json rename to enumeration/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2-arn_aws_sns_eu-west-3_526954929923_user-updates-topic2.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3.res.golden.json b/enumeration/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3.res.golden.json rename to enumeration/remote/test/aws_sns_topic_multiple/aws_sns_topic-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3-arn_aws_sns_eu-west-3_526954929923_user-updates-topic3.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_multiple/main.tf b/enumeration/remote/test/aws_sns_topic_multiple/main.tf similarity index 100% rename from pkg/remote/test/aws_sns_topic_multiple/main.tf rename to enumeration/remote/test/aws_sns_topic_multiple/main.tf diff --git a/pkg/remote/test/aws_sns_topic_multiple/results.golden.json b/enumeration/remote/test/aws_sns_topic_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_multiple/results.golden.json rename to enumeration/remote/test/aws_sns_topic_multiple/results.golden.json diff --git a/pkg/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy.res.golden.json b/enumeration/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy.res.golden.json rename to enumeration/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2.res.golden.json b/enumeration/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2.res.golden.json rename to enumeration/remote/test/aws_sns_topic_policy_multiple/aws_sns_topic_policy-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2-arn_aws_sns_us-east-1_526954929923_my-topic-with-policy2.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_policy_multiple/results.golden.json b/enumeration/remote/test/aws_sns_topic_policy_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_policy_multiple/results.golden.json rename to enumeration/remote/test/aws_sns_topic_policy_multiple/results.golden.json diff --git a/pkg/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa.res.golden.json b/enumeration/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa.res.golden.json rename to enumeration/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa-arn_aws_sns_us-east-1_526954929923_user-updates-topic2_c0f794c5-a009-4db4-9147-4c55959787fa.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e.res.golden.json b/enumeration/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e.res.golden.json rename to enumeration/remote/test/aws_sns_topic_subscription_multiple/aws_sns_topic_subscription-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e-arn_aws_sns_us-east-1_526954929923_user-updates-topic_b6e66147-2b31-4486-8d4b-2a2272264c8e.res.golden.json diff --git a/pkg/remote/test/aws_sns_topic_subscription_multiple/main.tf b/enumeration/remote/test/aws_sns_topic_subscription_multiple/main.tf similarity index 100% rename from pkg/remote/test/aws_sns_topic_subscription_multiple/main.tf rename to enumeration/remote/test/aws_sns_topic_subscription_multiple/main.tf diff --git a/pkg/remote/test/aws_sns_topic_subscription_multiple/results.golden.json b/enumeration/remote/test/aws_sns_topic_subscription_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_sns_topic_subscription_multiple/results.golden.json rename to enumeration/remote/test/aws_sns_topic_subscription_multiple/results.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json b/enumeration/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json rename to enumeration/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json b/enumeration/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json rename to enumeration/remote/test/aws_sqs_queue_multiple/aws_sqs_queue-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_multiple/results.golden.json b/enumeration/remote/test/aws_sqs_queue_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_multiple/results.golden.json rename to enumeration/remote/test/aws_sqs_queue_multiple/results.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_multiple/terraform.tf b/enumeration/remote/test/aws_sqs_queue_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_sqs_queue_multiple/terraform.tf rename to enumeration/remote/test/aws_sqs_queue_multiple/terraform.tf diff --git a/pkg/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json b/enumeration/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json rename to enumeration/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_bar.fifo.res.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_baz.res.golden.json b/enumeration/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_baz.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_baz.res.golden.json rename to enumeration/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_baz.res.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json b/enumeration/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json rename to enumeration/remote/test/aws_sqs_queue_policy_multiple/aws_sqs_queue_policy-https___sqs.eu-west-3.amazonaws.com_047081014315_foo.res.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_policy_multiple/policy.json b/enumeration/remote/test/aws_sqs_queue_policy_multiple/policy.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_policy_multiple/policy.json rename to enumeration/remote/test/aws_sqs_queue_policy_multiple/policy.json diff --git a/pkg/remote/test/aws_sqs_queue_policy_multiple/results.golden.json b/enumeration/remote/test/aws_sqs_queue_policy_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_sqs_queue_policy_multiple/results.golden.json rename to enumeration/remote/test/aws_sqs_queue_policy_multiple/results.golden.json diff --git a/pkg/remote/test/aws_sqs_queue_policy_multiple/terraform.tf b/enumeration/remote/test/aws_sqs_queue_policy_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_sqs_queue_policy_multiple/terraform.tf rename to enumeration/remote/test/aws_sqs_queue_policy_multiple/terraform.tf diff --git a/pkg/remote/test/aws_vpc/aws_vpc-vpc-020b072316a95b97f.res.golden.json b/enumeration/remote/test/aws_vpc/aws_vpc-vpc-020b072316a95b97f.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc/aws_vpc-vpc-020b072316a95b97f.res.golden.json rename to enumeration/remote/test/aws_vpc/aws_vpc-vpc-020b072316a95b97f.res.golden.json diff --git a/pkg/remote/test/aws_vpc/aws_vpc-vpc-02c50896b59598761.res.golden.json b/enumeration/remote/test/aws_vpc/aws_vpc-vpc-02c50896b59598761.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc/aws_vpc-vpc-02c50896b59598761.res.golden.json rename to enumeration/remote/test/aws_vpc/aws_vpc-vpc-02c50896b59598761.res.golden.json diff --git a/pkg/remote/test/aws_vpc/aws_vpc-vpc-0768e1fd0029e3fc3.res.golden.json b/enumeration/remote/test/aws_vpc/aws_vpc-vpc-0768e1fd0029e3fc3.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc/aws_vpc-vpc-0768e1fd0029e3fc3.res.golden.json rename to enumeration/remote/test/aws_vpc/aws_vpc-vpc-0768e1fd0029e3fc3.res.golden.json diff --git a/pkg/remote/test/aws_vpc/aws_vpc-vpc-a8c5d4c1.res.golden.json b/enumeration/remote/test/aws_vpc/aws_vpc-vpc-a8c5d4c1.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc/aws_vpc-vpc-a8c5d4c1.res.golden.json rename to enumeration/remote/test/aws_vpc/aws_vpc-vpc-a8c5d4c1.res.golden.json diff --git a/pkg/remote/test/aws_vpc/results.golden.json b/enumeration/remote/test/aws_vpc/results.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc/results.golden.json rename to enumeration/remote/test/aws_vpc/results.golden.json diff --git a/pkg/remote/test/aws_vpc/terraform.tf b/enumeration/remote/test/aws_vpc/terraform.tf similarity index 100% rename from pkg/remote/test/aws_vpc/terraform.tf rename to enumeration/remote/test/aws_vpc/terraform.tf diff --git a/pkg/remote/test/aws_vpc_default_security_group_multiple/aws_default_security_group-sg-9e0204ff.res.golden.json b/enumeration/remote/test/aws_vpc_default_security_group_multiple/aws_default_security_group-sg-9e0204ff.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_default_security_group_multiple/aws_default_security_group-sg-9e0204ff.res.golden.json rename to enumeration/remote/test/aws_vpc_default_security_group_multiple/aws_default_security_group-sg-9e0204ff.res.golden.json diff --git a/pkg/remote/test/aws_vpc_default_security_group_multiple/results.golden.json b/enumeration/remote/test/aws_vpc_default_security_group_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_default_security_group_multiple/results.golden.json rename to enumeration/remote/test/aws_vpc_default_security_group_multiple/results.golden.json diff --git a/pkg/remote/test/aws_vpc_default_security_group_multiple/terraform.tf b/enumeration/remote/test/aws_vpc_default_security_group_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_vpc_default_security_group_multiple/terraform.tf rename to enumeration/remote/test/aws_vpc_default_security_group_multiple/terraform.tf diff --git a/pkg/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-3820791514-1-1.2.0.0_16-0-0-0--1-sg-a74815c8-false--0-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-3820791514-1-1.2.0.0_16-0-0-0--1-sg-a74815c8-false--0-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-3820791514-1-1.2.0.0_16-0-0-0--1-sg-a74815c8-false--0-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-3820791514-1-1.2.0.0_16-0-0-0--1-sg-a74815c8-false--0-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-529352477-1-1.2.3.4_32-0-0-0--1-sg-a74815c8-false--0-egress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-529352477-1-1.2.3.4_32-0-0-0--1-sg-a74815c8-false--0-egress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-529352477-1-1.2.3.4_32-0-0-0--1-sg-a74815c8-false--0-egress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_default_rules/aws_security_group_rule-sgrule-529352477-1-1.2.3.4_32-0-0-0--1-sg-a74815c8-false--0-egress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_default_rules/results.golden.json b/enumeration/remote/test/aws_vpc_security_group_default_rules/results.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_default_rules/results.golden.json rename to enumeration/remote/test/aws_vpc_security_group_default_rules/results.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_multiple/aws_security_group-sg-0254c038e32f25530.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_multiple/aws_security_group-sg-0254c038e32f25530.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_multiple/aws_security_group-sg-0254c038e32f25530.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_multiple/aws_security_group-sg-0254c038e32f25530.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_multiple/results.golden.json b/enumeration/remote/test/aws_vpc_security_group_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_multiple/results.golden.json rename to enumeration/remote/test/aws_vpc_security_group_multiple/results.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_multiple/terraform.tf b/enumeration/remote/test/aws_vpc_security_group_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_multiple/terraform.tf rename to enumeration/remote/test/aws_vpc_security_group_multiple/terraform.tf diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1175318309-1-0.0.0.0_0-0--1-sg-0cc8b3c3c2851705a-false-0-egress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1175318309-1-0.0.0.0_0-0--1-sg-0cc8b3c3c2851705a-false--0-egress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1175318309-1-0.0.0.0_0-0--1-sg-0cc8b3c3c2851705a-false-0-egress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1175318309-1-0.0.0.0_0-0--1-sg-0cc8b3c3c2851705a-false--0-egress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1707973622-1-0.0.0.0_0-0--1-sg-0254c038e32f25530-false-0-egress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1707973622-1-0.0.0.0_0-0--1-sg-0254c038e32f25530-false--0-egress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1707973622-1-0.0.0.0_0-0--1-sg-0254c038e32f25530-false-0-egress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-1707973622-1-0.0.0.0_0-0--1-sg-0254c038e32f25530-false--0-egress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2165103420-1-5.6.7.0_24-0--1-sg-0254c038e32f25530-false-0-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2165103420-1-5.6.7.0_24-0--1-sg-0254c038e32f25530-false--0-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2165103420-1-5.6.7.0_24-0--1-sg-0254c038e32f25530-false-0-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2165103420-1-5.6.7.0_24-0--1-sg-0254c038e32f25530-false--0-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2582518759-1-1.2.0.0_16-0--1-sg-0254c038e32f25530-false-0-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2582518759-1-1.2.0.0_16-0--1-sg-0254c038e32f25530-false--0-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2582518759-1-1.2.0.0_16-0--1-sg-0254c038e32f25530-false-0-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2582518759-1-1.2.0.0_16-0--1-sg-0254c038e32f25530-false--0-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2821752134-0-1-___0--1-sg-0254c038e32f25530-false-0-egress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2821752134-0-1-___0--1-sg-0254c038e32f25530-false--0-egress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2821752134-0-1-___0--1-sg-0254c038e32f25530-false-0-egress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-2821752134-0-1-___0--1-sg-0254c038e32f25530-false--0-egress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-350400929-0-1-___0--1-sg-0cc8b3c3c2851705a-false-0-egress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-350400929-0-1-___0--1-sg-0cc8b3c3c2851705a-false--0-egress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-350400929-0-1-___0--1-sg-0cc8b3c3c2851705a-false-0-egress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-350400929-0-1-___0--1-sg-0cc8b3c3c2851705a-false--0-egress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3587309474-0-tcp-sg-0254c038e32f25530-false-sg-9e0204ff-65535-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3587309474-0-tcp-sg-0254c038e32f25530-false-sg-9e0204ff-65535-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3587309474-0-tcp-sg-0254c038e32f25530-false-sg-9e0204ff-65535-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3587309474-0-tcp-sg-0254c038e32f25530-false-sg-9e0204ff-65535-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3970541193-0-tcp-sg-0254c038e32f25530-true-65535-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3970541193-0-tcp-sg-0254c038e32f25530-true--65535-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3970541193-0-tcp-sg-0254c038e32f25530-true-65535-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-3970541193-0-tcp-sg-0254c038e32f25530-true--65535-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-453320892-1-0.0.0.0_0-443-tcp-sg-0cc8b3c3c2851705a-false-443-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-453320892-1-0.0.0.0_0-443-tcp-sg-0cc8b3c3c2851705a-false--443-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-453320892-1-0.0.0.0_0-443-tcp-sg-0cc8b3c3c2851705a-false-443-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-453320892-1-0.0.0.0_0-443-tcp-sg-0cc8b3c3c2851705a-false--443-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-674800228-0-1-___0--1-sg-0254c038e32f25530-false-0-ingress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-674800228-0-1-___0--1-sg-0254c038e32f25530-false--0-ingress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-674800228-0-1-___0--1-sg-0254c038e32f25530-false-0-ingress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-674800228-0-1-___0--1-sg-0254c038e32f25530-false--0-ingress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-850043874-1-0.0.0.0_0-5-sg-0cc8b3c3c2851705a-false-egress.res.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-850043874-1-0.0.0.0_0-0-5-sg-0cc8b3c3c2851705a-false--0-egress.res.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-850043874-1-0.0.0.0_0-5-sg-0cc8b3c3c2851705a-false-egress.res.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/aws_security_group_rule-sgrule-850043874-1-0.0.0.0_0-0-5-sg-0cc8b3c3c2851705a-false--0-egress.res.golden.json diff --git a/pkg/remote/test/aws_vpc_security_group_rule_multiple/results.golden.json b/enumeration/remote/test/aws_vpc_security_group_rule_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/aws_vpc_security_group_rule_multiple/results.golden.json rename to enumeration/remote/test/aws_vpc_security_group_rule_multiple/results.golden.json diff --git a/pkg/remote/test/azurerm_lb_rule_multiple/17c4e2b7f4cc466670ccbd8dd1d506289fd818c0.res.golden.json b/enumeration/remote/test/azurerm_lb_rule_multiple/17c4e2b7f4cc466670ccbd8dd1d506289fd818c0.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_lb_rule_multiple/17c4e2b7f4cc466670ccbd8dd1d506289fd818c0.res.golden.json rename to enumeration/remote/test/azurerm_lb_rule_multiple/17c4e2b7f4cc466670ccbd8dd1d506289fd818c0.res.golden.json diff --git a/pkg/remote/test/azurerm_lb_rule_multiple/2f725aa2c71898229fcfa091837deaa387bee9d8.res.golden.json b/enumeration/remote/test/azurerm_lb_rule_multiple/2f725aa2c71898229fcfa091837deaa387bee9d8.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_lb_rule_multiple/2f725aa2c71898229fcfa091837deaa387bee9d8.res.golden.json rename to enumeration/remote/test/azurerm_lb_rule_multiple/2f725aa2c71898229fcfa091837deaa387bee9d8.res.golden.json diff --git a/enumeration/remote/test/azurerm_lb_rule_multiple/results.golden.json b/enumeration/remote/test/azurerm_lb_rule_multiple/results.golden.json new file mode 100755 index 00000000..6266d781 --- /dev/null +++ b/enumeration/remote/test/azurerm_lb_rule_multiple/results.golden.json @@ -0,0 +1,40 @@ +[ + { + "backend_address_pool_id": "", + "backend_port": 80, + "disable_outbound_snat": false, + "enable_floating_ip": false, + "enable_tcp_reset": false, + "frontend_ip_configuration_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", + "frontend_ip_configuration_name": "PublicIPAddress", + "frontend_port": 80, + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule2", + "idle_timeout_in_minutes": 4, + "load_distribution": "Default", + "loadbalancer_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer", + "name": "LBRule2", + "probe_id": "", + "protocol": "Tcp", + "resource_group_name": "raphael-dev", + "timeouts": {} + }, + { + "backend_address_pool_id": "", + "backend_port": 3389, + "disable_outbound_snat": false, + "enable_floating_ip": false, + "enable_tcp_reset": false, + "frontend_ip_configuration_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", + "frontend_ip_configuration_name": "PublicIPAddress", + "frontend_port": 3389, + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule", + "idle_timeout_in_minutes": 4, + "load_distribution": "Default", + "loadbalancer_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer", + "name": "LBRule", + "probe_id": "", + "protocol": "Tcp", + "resource_group_name": "raphael-dev", + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup1.res.golden.json b/enumeration/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup1.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup1.res.golden.json rename to enumeration/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup1.res.golden.json diff --git a/pkg/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup2.res.golden.json b/enumeration/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup2.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup2.res.golden.json rename to enumeration/remote/test/azurerm_network_security_group_multiple/azurerm_network_security_group-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_example-resources_providers_Microsoft.Network_networkSecurityGroups_acceptanceTestSecurityGroup2.res.golden.json diff --git a/enumeration/remote/test/azurerm_network_security_group_multiple/results.golden.json b/enumeration/remote/test/azurerm_network_security_group_multiple/results.golden.json new file mode 100755 index 00000000..3daa87ad --- /dev/null +++ b/enumeration/remote/test/azurerm_network_security_group_multiple/results.golden.json @@ -0,0 +1,41 @@ +[ + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup2", + "location": "westeurope", + "name": "acceptanceTestSecurityGroup2", + "resource_group_name": "example-resources", + "security_rule": null, + "tags": null, + "timeouts": {} + }, + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup1", + "location": "westeurope", + "name": "acceptanceTestSecurityGroup1", + "resource_group_name": "example-resources", + "security_rule": [ + { + "access": "Allow", + "description": "", + "destination_address_prefix": "*", + "destination_address_prefixes": null, + "destination_application_security_group_ids": null, + "destination_port_range": "*", + "destination_port_ranges": null, + "direction": "Inbound", + "name": "test123", + "priority": 100, + "protocol": "Tcp", + "source_address_prefix": "*", + "source_address_prefixes": null, + "source_application_security_group_ids": null, + "source_port_range": "*", + "source_port_ranges": null + } + ], + "tags": { + "environment": "Production" + }, + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_othertest.res.golden.json b/enumeration/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_othertest.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_othertest.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_othertest.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_test.res.golden.json b/enumeration/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_test.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_test.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_a_record_multiple/azurerm_private_dns_a_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_A_test.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_a_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_a_record_multiple/results.golden.json new file mode 100755 index 00000000..f0b58fc9 --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_a_record_multiple/results.golden.json @@ -0,0 +1,29 @@ +[ + { + "fqdn": "test.thisisatestusingtf.com.", + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/test", + "name": "test", + "records": [ + "10.0.180.17", + "10.0.180.20" + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "othertest.thisisatestusingtf.com.", + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/othertest", + "name": "othertest", + "records": [ + "10.0.180.20" + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_othertest.res.golden.json b/enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_othertest.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_othertest.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_othertest.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_test.res.golden.json b/enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_test.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_test.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/azurerm_private_dns_aaaa_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_AAAA_test.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/results.golden.json new file mode 100755 index 00000000..cc30ef5a --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_aaaaa_record_multiple/results.golden.json @@ -0,0 +1,30 @@ +[ + { + "fqdn": "othertest.thisisatestusingtf.com.", + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/othertest", + "name": "othertest", + "records": [ + "fd5d:70bc:930e:d008:0000:0000:0000:7334", + "fd5d:70bc:930e:d008::7335" + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "test.thisisatestusingtf.com.", + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/test", + "name": "test", + "records": [ + "fd5d:70bc:930e:d008:0000:0000:0000:7334", + "fd5d:70bc:930e:d008::7335" + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_othertest.res.golden.json b/enumeration/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_othertest.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_othertest.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_othertest.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_test.res.golden.json b/enumeration/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_test.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_test.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_cname_record_multiple/azurerm_private_dns_cname_record-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_CNAME_test.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_cname_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_cname_record_multiple/results.golden.json new file mode 100755 index 00000000..db9064a6 --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_cname_record_multiple/results.golden.json @@ -0,0 +1,24 @@ +[ + { + "fqdn": "test.thisisatestusingtf.com.", + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/test", + "name": "test", + "record": "test.com", + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "othertest.thisisatestusingtf.com.", + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/othertest", + "name": "othertest", + "record": "othertest.com", + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_othertestmx.res.golden.json b/enumeration/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_othertestmx.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_othertestmx.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_othertestmx.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_testmx.res.golden.json b/enumeration/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_testmx.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_testmx.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_mx_record_multiple/azurerm_private_dns_mx_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_MX_testmx.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_mx_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_mx_record_multiple/results.golden.json new file mode 100755 index 00000000..c0f56139 --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_mx_record_multiple/results.golden.json @@ -0,0 +1,38 @@ +[ + { + "fqdn": "testmx.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/testmx", + "name": "testmx", + "record": [ + { + "exchange": "bkpmx.thisisatestusingtf.com", + "preference": 30 + } + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "othertestmx.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/othertestmx", + "name": "othertestmx", + "record": [ + { + "exchange": "backupmx.thisisatestusingtf.com", + "preference": 20 + }, + { + "exchange": "mx.thisisatestusingtf.com", + "preference": 10 + } + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_testmartin.com.res.golden.json b/enumeration/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_testmartin.com.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_testmartin.com.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_testmartin.com.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com.res.golden.json b/enumeration/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf2.com.res.golden.json b/enumeration/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf2.com.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf2.com.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_private_zone_multiple/azurerm_private_dns_zone-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf2.com.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_private_zone_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_private_zone_multiple/results.golden.json new file mode 100755 index 00000000..a1065965 --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_private_zone_multiple/results.golden.json @@ -0,0 +1,79 @@ +[ + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf2.com", + "max_number_of_record_sets": 25000, + "max_number_of_virtual_network_links": 1000, + "max_number_of_virtual_network_links_with_registration": 100, + "name": "thisisatestusingtf2.com", + "number_of_record_sets": 1, + "resource_group_name": "martin-dev", + "soa_record": [ + { + "email": "azureprivatedns-host.microsoft.com", + "expire_time": 2419200, + "fqdn": "thisisatestusingtf2.com.", + "host_name": "azureprivatedns.net", + "minimum_ttl": 10, + "refresh_time": 3600, + "retry_time": 300, + "serial_number": 1, + "tags": null, + "ttl": 3600 + } + ], + "tags": null, + "timeouts": {} + }, + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/testmartin.com", + "max_number_of_record_sets": 25000, + "max_number_of_virtual_network_links": 1000, + "max_number_of_virtual_network_links_with_registration": 100, + "name": "testmartin.com", + "number_of_record_sets": 2, + "resource_group_name": "martin-dev", + "soa_record": [ + { + "email": "azureprivatedns-host.microsoft.com", + "expire_time": 2419200, + "fqdn": "testmartin.com.", + "host_name": "azureprivatedns.net", + "minimum_ttl": 10, + "refresh_time": 3600, + "retry_time": 300, + "serial_number": 1, + "tags": null, + "ttl": 3600 + } + ], + "tags": { + "test": "test" + }, + "timeouts": {} + }, + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com", + "max_number_of_record_sets": 25000, + "max_number_of_virtual_network_links": 1000, + "max_number_of_virtual_network_links_with_registration": 100, + "name": "thisisatestusingtf.com", + "number_of_record_sets": 5, + "resource_group_name": "martin-dev", + "soa_record": [ + { + "email": "azureprivatedns-host.microsoft.com", + "expire_time": 2419200, + "fqdn": "thisisatestusingtf.com.", + "host_name": "azureprivatedns.net", + "minimum_ttl": 10, + "refresh_time": 3600, + "retry_time": 300, + "serial_number": 1, + "tags": null, + "ttl": 3600 + } + ], + "tags": null, + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_othertestptr.res.golden.json b/enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_othertestptr.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_othertestptr.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_othertestptr.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_testptr.res.golden.json b/enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_testptr.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_testptr.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/azurerm_private_dns_ptr_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_PTR_testptr.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/results.golden.json new file mode 100755 index 00000000..3a42c39b --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_ptr_record_multiple/results.golden.json @@ -0,0 +1,29 @@ +[ + { + "fqdn": "testptr.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/testptr", + "name": "testptr", + "records": [ + "ptr3.thisisatestusingtf.com" + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "othertestptr.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/othertestptr", + "name": "othertestptr", + "records": [ + "ptr1.thisisatestusingtf.com", + "ptr2.thisisatestusingtf.com" + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_othertestptr.res.golden.json b/enumeration/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_othertestptr.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_othertestptr.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_othertestptr.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_testptr.res.golden.json b/enumeration/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_testptr.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_testptr.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_srv_record_multiple/azurerm_private_dns_srv_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_SRV_testptr.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_srv_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_srv_record_multiple/results.golden.json new file mode 100755 index 00000000..08ca167d --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_srv_record_multiple/results.golden.json @@ -0,0 +1,44 @@ +[ + { + "fqdn": "othertestptr.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/othertestptr", + "name": "othertestptr", + "record": [ + { + "port": 8080, + "priority": 10, + "target": "srv2.thisisatestusingtf.com", + "weight": 10 + }, + { + "port": 8080, + "priority": 1, + "target": "srv1.thisisatestusingtf.com", + "weight": 5 + } + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "testptr.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/testptr", + "name": "testptr", + "record": [ + { + "port": 8080, + "priority": 20, + "target": "srv3.thisisatestusingtf.com", + "weight": 15 + } + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_othertesttxt.res.golden.json b/enumeration/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_othertesttxt.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_othertesttxt.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_othertesttxt.res.golden.json diff --git a/pkg/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_testtxt.res.golden.json b/enumeration/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_testtxt.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_testtxt.res.golden.json rename to enumeration/remote/test/azurerm_private_dns_txt_record_multiple/azurerm_private_dns_txt_record-_subscriptions_8cb43347-a79f-4bb2-a8b4-c838b41fa5a5_resourceGroups_martin-dev_providers_Microsoft.Network_privateDnsZones_thisisatestusingtf.com_TXT_testtxt.res.golden.json diff --git a/enumeration/remote/test/azurerm_private_dns_txt_record_multiple/results.golden.json b/enumeration/remote/test/azurerm_private_dns_txt_record_multiple/results.golden.json new file mode 100755 index 00000000..7dfe177f --- /dev/null +++ b/enumeration/remote/test/azurerm_private_dns_txt_record_multiple/results.golden.json @@ -0,0 +1,35 @@ +[ + { + "fqdn": "testtxt.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/testtxt", + "name": "testtxt", + "record": [ + { + "value": "this is value line 3" + } + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + }, + { + "fqdn": "othertesttxt.thisisatestusingtf.com.", + "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/othertesttxt", + "name": "othertesttxt", + "record": [ + { + "value": "this is value line 1" + }, + { + "value": "this is value line 2" + } + ], + "resource_group_name": "martin-dev", + "tags": null, + "timeouts": {}, + "ttl": 300, + "zone_name": "thisisatestusingtf.com" + } +] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key.res.golden.json b/enumeration/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key.res.golden.json rename to enumeration/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key.res.golden.json diff --git a/pkg/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key2.res.golden.json b/enumeration/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key2.res.golden.json similarity index 100% rename from pkg/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key2.res.golden.json rename to enumeration/remote/test/azurerm_ssh_public_key_multiple/azurerm_ssh_public_key-_subscriptions_7bfb2c5c-7308-46ed-8ae4-fffa356eb406_resourceGroups_TESTRESGROUP_providers_Microsoft.Compute_sshPublicKeys_example-key2.res.golden.json diff --git a/enumeration/remote/test/azurerm_ssh_public_key_multiple/results.golden.json b/enumeration/remote/test/azurerm_ssh_public_key_multiple/results.golden.json new file mode 100755 index 00000000..4cef8a32 --- /dev/null +++ b/enumeration/remote/test/azurerm_ssh_public_key_multiple/results.golden.json @@ -0,0 +1,20 @@ +[ + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key2", + "location": "westeurope", + "name": "example-key2", + "public_key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCjeC5sO1EdEfZOrdVCpuOgXcXsKZg9zgfJbHgQgX1R2Nd8mNQrUjpsB4XLHNZ3T6UYrsSh7oxYC3UFu6peO4LmA2WTe2wCWVFn9WW/Lo99WcA/G/fGj6s5HK5CHFVPXnNxM47QJMNm5BWOM55+EWP839SHLH9Fk63H575x7jxZvBvaV0uL84XuVpiEBKhnpQfT4cJGoGLOGgjM+TpHyosbKldu5q2UTF9nOGpmLuku41oihqiPPSJnJRv3TDKFi4mIl9Iz5HJINWvLl1kdCfyjPcHcH5GO0tuA9rP5AbsmG5EAGOKtuFipYA4MyY9SYriZ2V1vpgefUS9lilg9hIPEj/8ZPTxf62XeyC1dQ3cOz6yPWR2sODyVECVf6mrmhZPTjVX+DorByX2uBzLDzF9jGMFMJRhxi0yVpXsqBrP+ps9G+s7oNUDp771d1Bix+gm5EyebEbdiQuf0/8wDlhY5jYAFJW1xkPKXcjJdM1FuVVS1B8zhvRVEJZUngruVfh/7jJUOWNS44F7rVz5a4r/vs84ObFIMeYdFn+uxgUqOlNMAvXLvJ2GzlPXInXW90Uv+JJ5msny/5ygGfHr2D6xOf6P7r7oSalXwjd9BcRS6/4GQAY6LVfPwrpnrpyJBiK/FhEbR+ctfDo81eKhmp0EyxvSJGW46/26/kqHvchf+rQ== ?\n", + "resource_group_name": "TESTRESGROUP", + "tags": null, + "timeouts": {} + }, + { + "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key", + "location": "westeurope", + "name": "example-key", + "public_key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCjeC5sO1EdEfZOrdVCpuOgXcXsKZg9zgfJbHgQgX1R2Nd8mNQrUjpsB4XLHNZ3T6UYrsSh7oxYC3UFu6peO4LmA2WTe2wCWVFn9WW/Lo99WcA/G/fGj6s5HK5CHFVPXnNxM47QJMNm5BWOM55+EWP839SHLH9Fk63H575x7jxZvBvaV0uL84XuVpiEBKhnpQfT4cJGoGLOGgjM+TpHyosbKldu5q2UTF9nOGpmLuku41oihqiPPSJnJRv3TDKFi4mIl9Iz5HJINWvLl1kdCfyjPcHcH5GO0tuA9rP5AbsmG5EAGOKtuFipYA4MyY9SYriZ2V1vpgefUS9lilg9hIPEj/8ZPTxf62XeyC1dQ3cOz6yPWR2sODyVECVf6mrmhZPTjVX+DorByX2uBzLDzF9jGMFMJRhxi0yVpXsqBrP+ps9G+s7oNUDp771d1Bix+gm5EyebEbdiQuf0/8wDlhY5jYAFJW1xkPKXcjJdM1FuVVS1B8zhvRVEJZUngruVfh/7jJUOWNS44F7rVz5a4r/vs84ObFIMeYdFn+uxgUqOlNMAvXLvJ2GzlPXInXW90Uv+JJ5msny/5ygGfHr2D6xOf6P7r7oSalXwjd9BcRS6/4GQAY6LVfPwrpnrpyJBiK/FhEbR+ctfDo81eKhmp0EyxvSJGW46/26/kqHvchf+rQ== ?\n", + "resource_group_name": "TESTRESGROUP", + "tags": null, + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzE=.res.golden.json b/enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzE=.res.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzE=.res.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzE=.res.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzI=.res.golden.json b/enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzI=.res.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzI=.res.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzI=.res.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzQ=.res.golden.json b/enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzQ=.res.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzQ=.res.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzQ=.res.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzc=.res.golden.json b/enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzc=.res.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzc=.res.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzc=.res.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzg=.res.golden.json b/enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzg=.res.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzg=.res.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzg=.res.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0ODA=.res.golden.json b/enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0ODA=.res.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0ODA=.res.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/github_branch_protection-MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0ODA=.res.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/results.golden.json b/enumeration/remote/test/github_branch_protection_multiples/results.golden.json similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/results.golden.json rename to enumeration/remote/test/github_branch_protection_multiples/results.golden.json diff --git a/pkg/remote/test/github_branch_protection_multiples/terraform.tf b/enumeration/remote/test/github_branch_protection_multiples/terraform.tf similarity index 100% rename from pkg/remote/test/github_branch_protection_multiples/terraform.tf rename to enumeration/remote/test/github_branch_protection_multiples/terraform.tf diff --git a/pkg/remote/test/github_membership_multiple/github_membership-driftctl-test_driftctl-acceptance-tester.res.golden.json b/enumeration/remote/test/github_membership_multiple/github_membership-driftctl-test_driftctl-acceptance-tester.res.golden.json similarity index 100% rename from pkg/remote/test/github_membership_multiple/github_membership-driftctl-test_driftctl-acceptance-tester.res.golden.json rename to enumeration/remote/test/github_membership_multiple/github_membership-driftctl-test_driftctl-acceptance-tester.res.golden.json diff --git a/pkg/remote/test/github_membership_multiple/github_membership-driftctl-test_eliecharra.res.golden.json b/enumeration/remote/test/github_membership_multiple/github_membership-driftctl-test_eliecharra.res.golden.json similarity index 100% rename from pkg/remote/test/github_membership_multiple/github_membership-driftctl-test_eliecharra.res.golden.json rename to enumeration/remote/test/github_membership_multiple/github_membership-driftctl-test_eliecharra.res.golden.json diff --git a/pkg/remote/test/github_membership_multiple/results.golden.json b/enumeration/remote/test/github_membership_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/github_membership_multiple/results.golden.json rename to enumeration/remote/test/github_membership_multiple/results.golden.json diff --git a/pkg/remote/test/github_repository_multiple/github_repository-driftctl-demos.res.golden.json b/enumeration/remote/test/github_repository_multiple/github_repository-driftctl-demos.res.golden.json similarity index 100% rename from pkg/remote/test/github_repository_multiple/github_repository-driftctl-demos.res.golden.json rename to enumeration/remote/test/github_repository_multiple/github_repository-driftctl-demos.res.golden.json diff --git a/pkg/remote/test/github_repository_multiple/github_repository-driftctl.res.golden.json b/enumeration/remote/test/github_repository_multiple/github_repository-driftctl.res.golden.json similarity index 100% rename from pkg/remote/test/github_repository_multiple/github_repository-driftctl.res.golden.json rename to enumeration/remote/test/github_repository_multiple/github_repository-driftctl.res.golden.json diff --git a/pkg/remote/test/github_repository_multiple/results.golden.json b/enumeration/remote/test/github_repository_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/github_repository_multiple/results.golden.json rename to enumeration/remote/test/github_repository_multiple/results.golden.json diff --git a/pkg/remote/test/github_team_membership_multiple/github_team_membership-4570529_driftctl-acceptance-tester.res.golden.json b/enumeration/remote/test/github_team_membership_multiple/github_team_membership-4570529_driftctl-acceptance-tester.res.golden.json similarity index 100% rename from pkg/remote/test/github_team_membership_multiple/github_team_membership-4570529_driftctl-acceptance-tester.res.golden.json rename to enumeration/remote/test/github_team_membership_multiple/github_team_membership-4570529_driftctl-acceptance-tester.res.golden.json diff --git a/pkg/remote/test/github_team_membership_multiple/github_team_membership-4570529_wbeuil.res.golden.json b/enumeration/remote/test/github_team_membership_multiple/github_team_membership-4570529_wbeuil.res.golden.json similarity index 100% rename from pkg/remote/test/github_team_membership_multiple/github_team_membership-4570529_wbeuil.res.golden.json rename to enumeration/remote/test/github_team_membership_multiple/github_team_membership-4570529_wbeuil.res.golden.json diff --git a/pkg/remote/test/github_team_membership_multiple/results.golden.json b/enumeration/remote/test/github_team_membership_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/github_team_membership_multiple/results.golden.json rename to enumeration/remote/test/github_team_membership_multiple/results.golden.json diff --git a/pkg/remote/test/github_team_membership_multiple/terraform.tf b/enumeration/remote/test/github_team_membership_multiple/terraform.tf similarity index 100% rename from pkg/remote/test/github_team_membership_multiple/terraform.tf rename to enumeration/remote/test/github_team_membership_multiple/terraform.tf diff --git a/pkg/remote/test/github_teams_multiple/github_team-4556811.res.golden.json b/enumeration/remote/test/github_teams_multiple/github_team-4556811.res.golden.json similarity index 100% rename from pkg/remote/test/github_teams_multiple/github_team-4556811.res.golden.json rename to enumeration/remote/test/github_teams_multiple/github_team-4556811.res.golden.json diff --git a/pkg/remote/test/github_teams_multiple/github_team-4556812.res.golden.json b/enumeration/remote/test/github_teams_multiple/github_team-4556812.res.golden.json similarity index 100% rename from pkg/remote/test/github_teams_multiple/github_team-4556812.res.golden.json rename to enumeration/remote/test/github_teams_multiple/github_team-4556812.res.golden.json diff --git a/pkg/remote/test/github_teams_multiple/github_team-4556814.res.golden.json b/enumeration/remote/test/github_teams_multiple/github_team-4556814.res.golden.json similarity index 100% rename from pkg/remote/test/github_teams_multiple/github_team-4556814.res.golden.json rename to enumeration/remote/test/github_teams_multiple/github_team-4556814.res.golden.json diff --git a/pkg/remote/test/github_teams_multiple/results.golden.json b/enumeration/remote/test/github_teams_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/github_teams_multiple/results.golden.json rename to enumeration/remote/test/github_teams_multiple/results.golden.json diff --git a/pkg/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-0-test-firewall-0-cloudskiff-dev-elie.res.golden.json b/enumeration/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-0-test-firewall-0-cloudskiff-dev-elie.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-0-test-firewall-0-cloudskiff-dev-elie.res.golden.json rename to enumeration/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-0-test-firewall-0-cloudskiff-dev-elie.res.golden.json diff --git a/pkg/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-1-test-firewall-1-cloudskiff-dev-elie.res.golden.json b/enumeration/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-1-test-firewall-1-cloudskiff-dev-elie.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-1-test-firewall-1-cloudskiff-dev-elie.res.golden.json rename to enumeration/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-1-test-firewall-1-cloudskiff-dev-elie.res.golden.json diff --git a/pkg/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-2-test-firewall-2-cloudskiff-dev-elie.res.golden.json b/enumeration/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-2-test-firewall-2-cloudskiff-dev-elie.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-2-test-firewall-2-cloudskiff-dev-elie.res.golden.json rename to enumeration/remote/test/google_compute_firewall/google_compute_firewall-projects_cloudskiff-dev-elie_global_firewalls_test-firewall-2-test-firewall-2-cloudskiff-dev-elie.res.golden.json diff --git a/enumeration/remote/test/google_compute_firewall/results.golden.json b/enumeration/remote/test/google_compute_firewall/results.golden.json new file mode 100755 index 00000000..b6948683 --- /dev/null +++ b/enumeration/remote/test/google_compute_firewall/results.golden.json @@ -0,0 +1,116 @@ +[ + { + "allow": [ + { + "ports": [ + "80", + "8080", + "1000-2000" + ], + "protocol": "tcp" + }, + { + "ports": null, + "protocol": "icmp" + } + ], + "creation_timestamp": "2021-09-14T05:21:08.730-07:00", + "deny": null, + "description": "", + "destination_ranges": null, + "direction": "INGRESS", + "disabled": false, + "enable_logging": null, + "id": "projects/cloudskiff-dev-elie/global/firewalls/test-firewall-1", + "log_config": null, + "name": "test-firewall-1", + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/networks/test-network", + "priority": 1000, + "project": "cloudskiff-dev-elie", + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-1", + "source_ranges": null, + "source_service_accounts": null, + "source_tags": [ + "web" + ], + "target_service_accounts": null, + "target_tags": null, + "timeouts": {} + }, + { + "allow": [ + { + "ports": [ + "80", + "8080", + "1000-2000" + ], + "protocol": "tcp" + }, + { + "ports": null, + "protocol": "icmp" + } + ], + "creation_timestamp": "2021-09-14T05:21:08.744-07:00", + "deny": null, + "description": "", + "destination_ranges": null, + "direction": "INGRESS", + "disabled": false, + "enable_logging": null, + "id": "projects/cloudskiff-dev-elie/global/firewalls/test-firewall-0", + "log_config": null, + "name": "test-firewall-0", + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/networks/test-network", + "priority": 1000, + "project": "cloudskiff-dev-elie", + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-0", + "source_ranges": null, + "source_service_accounts": null, + "source_tags": [ + "web" + ], + "target_service_accounts": null, + "target_tags": null, + "timeouts": {} + }, + { + "allow": [ + { + "ports": [ + "80", + "8080", + "1000-2000" + ], + "protocol": "tcp" + }, + { + "ports": null, + "protocol": "icmp" + } + ], + "creation_timestamp": "2021-09-14T05:21:08.624-07:00", + "deny": null, + "description": "", + "destination_ranges": null, + "direction": "INGRESS", + "disabled": false, + "enable_logging": null, + "id": "projects/cloudskiff-dev-elie/global/firewalls/test-firewall-2", + "log_config": null, + "name": "test-firewall-2", + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/networks/test-network", + "priority": 1000, + "project": "cloudskiff-dev-elie", + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-2", + "source_ranges": null, + "source_service_accounts": null, + "source_tags": [ + "web" + ], + "target_service_accounts": null, + "target_tags": null, + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_firewall/terraform.tf b/enumeration/remote/test/google_compute_firewall/terraform.tf similarity index 100% rename from pkg/remote/test/google_compute_firewall/terraform.tf rename to enumeration/remote/test/google_compute_firewall/terraform.tf diff --git a/pkg/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-1-driftctl-test-1-cloudskiff-dev-raphael-us-central1-a.res.golden.json b/enumeration/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-1-driftctl-test-1-cloudskiff-dev-raphael-us-central1-a.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-1-driftctl-test-1-cloudskiff-dev-raphael-us-central1-a.res.golden.json rename to enumeration/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-1-driftctl-test-1-cloudskiff-dev-raphael-us-central1-a.res.golden.json diff --git a/pkg/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-2-driftctl-test-2-cloudskiff-dev-raphael-us-central1-a.res.golden.json b/enumeration/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-2-driftctl-test-2-cloudskiff-dev-raphael-us-central1-a.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-2-driftctl-test-2-cloudskiff-dev-raphael-us-central1-a.res.golden.json rename to enumeration/remote/test/google_compute_instance_group/google_compute_instance_group-projects_cloudskiff-dev-raphael_zones_us-central1-a_instanceGroups_driftctl-test-2-driftctl-test-2-cloudskiff-dev-raphael-us-central1-a.res.golden.json diff --git a/enumeration/remote/test/google_compute_instance_group/results.golden.json b/enumeration/remote/test/google_compute_instance_group/results.golden.json new file mode 100755 index 00000000..b1bfeceb --- /dev/null +++ b/enumeration/remote/test/google_compute_instance_group/results.golden.json @@ -0,0 +1,28 @@ +[ + { + "description": "Terraform test instance group", + "id": "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-2", + "instances": null, + "name": "driftctl-test-2", + "named_port": null, + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network", + "project": "cloudskiff-dev-raphael", + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-2", + "size": 0, + "timeouts": {}, + "zone": "us-central1-a" + }, + { + "description": "Terraform test instance group", + "id": "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-1", + "instances": null, + "name": "driftctl-test-1", + "named_port": null, + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network", + "project": "cloudskiff-dev-raphael", + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-1", + "size": 0, + "timeouts": {}, + "zone": "us-central1-a" + } +] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_instance_group/terraform.tf b/enumeration/remote/test/google_compute_instance_group/terraform.tf similarity index 100% rename from pkg/remote/test/google_compute_instance_group/terraform.tf rename to enumeration/remote/test/google_compute_instance_group/terraform.tf diff --git a/pkg/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-1-driftctl-unittest-1.res.golden.json b/enumeration/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-1-driftctl-unittest-1.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-1-driftctl-unittest-1.res.golden.json rename to enumeration/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-1-driftctl-unittest-1.res.golden.json diff --git a/pkg/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-2-driftctl-unittest-2.res.golden.json b/enumeration/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-2-driftctl-unittest-2.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-2-driftctl-unittest-2.res.golden.json rename to enumeration/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-2-driftctl-unittest-2.res.golden.json diff --git a/pkg/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-3-driftctl-unittest-3.res.golden.json b/enumeration/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-3-driftctl-unittest-3.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-3-driftctl-unittest-3.res.golden.json rename to enumeration/remote/test/google_compute_network/google_compute_network-projects_driftctl-qa-1_global_networks_driftctl-unittest-3-driftctl-unittest-3.res.golden.json diff --git a/enumeration/remote/test/google_compute_network/results.golden.json b/enumeration/remote/test/google_compute_network/results.golden.json new file mode 100755 index 00000000..c160ca26 --- /dev/null +++ b/enumeration/remote/test/google_compute_network/results.golden.json @@ -0,0 +1,41 @@ +[ + { + "auto_create_subnetworks": false, + "delete_default_routes_on_create": false, + "description": "", + "gateway_ipv4": "", + "id": "projects/driftctl-qa-1/global/networks/driftctl-unittest-3", + "mtu": 1460, + "name": "driftctl-unittest-3", + "project": "driftctl-qa-1", + "routing_mode": "REGIONAL", + "self_link": "https://www.googleapis.com/compute/v1/projects/driftctl-qa-1/global/networks/driftctl-unittest-3", + "timeouts": {} + }, + { + "auto_create_subnetworks": true, + "delete_default_routes_on_create": false, + "description": "", + "gateway_ipv4": "", + "id": "projects/driftctl-qa-1/global/networks/driftctl-unittest-2", + "mtu": 1460, + "name": "driftctl-unittest-2", + "project": "driftctl-qa-1", + "routing_mode": "REGIONAL", + "self_link": "https://www.googleapis.com/compute/v1/projects/driftctl-qa-1/global/networks/driftctl-unittest-2", + "timeouts": {} + }, + { + "auto_create_subnetworks": false, + "delete_default_routes_on_create": false, + "description": "", + "gateway_ipv4": "", + "id": "projects/driftctl-qa-1/global/networks/driftctl-unittest-1", + "mtu": 1460, + "name": "driftctl-unittest-1", + "project": "driftctl-qa-1", + "routing_mode": "REGIONAL", + "self_link": "https://www.googleapis.com/compute/v1/projects/driftctl-qa-1/global/networks/driftctl-unittest-1", + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-1-driftctl-unittest-1-.res.golden.json b/enumeration/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-1-driftctl-unittest-1-.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-1-driftctl-unittest-1-.res.golden.json rename to enumeration/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-1-driftctl-unittest-1-.res.golden.json diff --git a/pkg/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-2-driftctl-unittest-2-.res.golden.json b/enumeration/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-2-driftctl-unittest-2-.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-2-driftctl-unittest-2-.res.golden.json rename to enumeration/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-2-driftctl-unittest-2-.res.golden.json diff --git a/pkg/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-3-driftctl-unittest-3-.res.golden.json b/enumeration/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-3-driftctl-unittest-3-.res.golden.json similarity index 100% rename from pkg/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-3-driftctl-unittest-3-.res.golden.json rename to enumeration/remote/test/google_compute_subnetwork_multiple/google_compute_subnetwork-projects_cloudskiff-dev-raphael_regions_us-central1_subnetworks_driftctl-unittest-3-driftctl-unittest-3-.res.golden.json diff --git a/enumeration/remote/test/google_compute_subnetwork_multiple/results.golden.json b/enumeration/remote/test/google_compute_subnetwork_multiple/results.golden.json new file mode 100755 index 00000000..074a306e --- /dev/null +++ b/enumeration/remote/test/google_compute_subnetwork_multiple/results.golden.json @@ -0,0 +1,71 @@ +[ + { + "creation_timestamp": "2021-10-20T07:39:34.673-07:00", + "description": "", + "fingerprint": null, + "gateway_address": "10.2.0.1", + "id": "projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-1", + "ip_cidr_range": "10.2.0.0/16", + "log_config": null, + "name": "driftctl-unittest-1", + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network-1871346572", + "private_ip_google_access": false, + "private_ipv6_google_access": "DISABLE_GOOGLE_ACCESS", + "project": "cloudskiff-dev-raphael", + "region": "us-central1", + "secondary_ip_range": [ + { + "ip_cidr_range": "192.168.10.0/24", + "range_name": "tf-test-secondary-range-update1" + } + ], + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-1", + "timeouts": {} + }, + { + "creation_timestamp": "2021-10-20T07:39:45.114-07:00", + "description": "", + "fingerprint": null, + "gateway_address": "10.2.0.1", + "id": "projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-2", + "ip_cidr_range": "10.2.0.0/16", + "log_config": null, + "name": "driftctl-unittest-2", + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network-2871346572", + "private_ip_google_access": false, + "private_ipv6_google_access": "DISABLE_GOOGLE_ACCESS", + "project": "cloudskiff-dev-raphael", + "region": "us-central1", + "secondary_ip_range": [ + { + "ip_cidr_range": "192.168.10.0/24", + "range_name": "tf-test-secondary-range-update1" + } + ], + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-2", + "timeouts": {} + }, + { + "creation_timestamp": "2021-10-20T07:39:34.650-07:00", + "description": "", + "fingerprint": null, + "gateway_address": "10.2.0.1", + "id": "projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-3", + "ip_cidr_range": "10.2.0.0/16", + "log_config": null, + "name": "driftctl-unittest-3", + "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network-3871346572", + "private_ip_google_access": false, + "private_ipv6_google_access": "DISABLE_GOOGLE_ACCESS", + "project": "cloudskiff-dev-raphael", + "region": "us-central1", + "secondary_ip_range": [ + { + "ip_cidr_range": "192.168.10.0/24", + "range_name": "tf-test-secondary-range-update1" + } + ], + "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-3", + "timeouts": {} + } +] \ No newline at end of file diff --git a/pkg/remote/test/google_project_member_listing_multiple/results.golden.json b/enumeration/remote/test/google_project_member_listing_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/google_project_member_listing_multiple/results.golden.json rename to enumeration/remote/test/google_project_member_listing_multiple/results.golden.json diff --git a/pkg/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-1-driftctl-unittest-1.res.golden.json b/enumeration/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-1-driftctl-unittest-1.res.golden.json similarity index 100% rename from pkg/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-1-driftctl-unittest-1.res.golden.json rename to enumeration/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-1-driftctl-unittest-1.res.golden.json diff --git a/pkg/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-2-driftctl-unittest-2.res.golden.json b/enumeration/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-2-driftctl-unittest-2.res.golden.json similarity index 100% rename from pkg/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-2-driftctl-unittest-2.res.golden.json rename to enumeration/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-2-driftctl-unittest-2.res.golden.json diff --git a/pkg/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-3-driftctl-unittest-3.res.golden.json b/enumeration/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-3-driftctl-unittest-3.res.golden.json similarity index 100% rename from pkg/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-3-driftctl-unittest-3.res.golden.json rename to enumeration/remote/test/google_storage_bucket/google_storage_bucket-driftctl-unittest-3-driftctl-unittest-3.res.golden.json diff --git a/pkg/remote/test/google_storage_bucket/results.golden.json b/enumeration/remote/test/google_storage_bucket/results.golden.json similarity index 100% rename from pkg/remote/test/google_storage_bucket/results.golden.json rename to enumeration/remote/test/google_storage_bucket/results.golden.json diff --git a/pkg/remote/test/google_storage_bucket_member_listing_multiple/results.golden.json b/enumeration/remote/test/google_storage_bucket_member_listing_multiple/results.golden.json similarity index 100% rename from pkg/remote/test/google_storage_bucket_member_listing_multiple/results.golden.json rename to enumeration/remote/test/google_storage_bucket_member_listing_multiple/results.golden.json diff --git a/pkg/resource/aws/aws_alb.go b/enumeration/resource/aws/aws_alb.go similarity index 100% rename from pkg/resource/aws/aws_alb.go rename to enumeration/resource/aws/aws_alb.go diff --git a/pkg/resource/aws/aws_alb_listener.go b/enumeration/resource/aws/aws_alb_listener.go similarity index 100% rename from pkg/resource/aws/aws_alb_listener.go rename to enumeration/resource/aws/aws_alb_listener.go diff --git a/enumeration/resource/aws/aws_ami.go b/enumeration/resource/aws/aws_ami.go new file mode 100644 index 00000000..ac8fec7d --- /dev/null +++ b/enumeration/resource/aws/aws_ami.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsAmiResourceType = "aws_ami" + +func initAwsAmiMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsAmiResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_api_gateway_account.go b/enumeration/resource/aws/aws_api_gateway_account.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_account.go rename to enumeration/resource/aws/aws_api_gateway_account.go diff --git a/pkg/resource/aws/aws_api_gateway_api_key.go b/enumeration/resource/aws/aws_api_gateway_api_key.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_api_key.go rename to enumeration/resource/aws/aws_api_gateway_api_key.go diff --git a/pkg/resource/aws/aws_api_gateway_authorizer.go b/enumeration/resource/aws/aws_api_gateway_authorizer.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_authorizer.go rename to enumeration/resource/aws/aws_api_gateway_authorizer.go diff --git a/pkg/resource/aws/aws_api_gateway_base_path_mapping.go b/enumeration/resource/aws/aws_api_gateway_base_path_mapping.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_base_path_mapping.go rename to enumeration/resource/aws/aws_api_gateway_base_path_mapping.go diff --git a/pkg/resource/aws/aws_api_gateway_deployment.go b/enumeration/resource/aws/aws_api_gateway_deployment.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_deployment.go rename to enumeration/resource/aws/aws_api_gateway_deployment.go diff --git a/pkg/resource/aws/aws_api_gateway_domain_name.go b/enumeration/resource/aws/aws_api_gateway_domain_name.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_domain_name.go rename to enumeration/resource/aws/aws_api_gateway_domain_name.go diff --git a/pkg/resource/aws/aws_api_gateway_gateway_response.go b/enumeration/resource/aws/aws_api_gateway_gateway_response.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_gateway_response.go rename to enumeration/resource/aws/aws_api_gateway_gateway_response.go diff --git a/pkg/resource/aws/aws_api_gateway_integration.go b/enumeration/resource/aws/aws_api_gateway_integration.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_integration.go rename to enumeration/resource/aws/aws_api_gateway_integration.go diff --git a/pkg/resource/aws/aws_api_gateway_integration_response.go b/enumeration/resource/aws/aws_api_gateway_integration_response.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_integration_response.go rename to enumeration/resource/aws/aws_api_gateway_integration_response.go diff --git a/pkg/resource/aws/aws_api_gateway_method.go b/enumeration/resource/aws/aws_api_gateway_method.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_method.go rename to enumeration/resource/aws/aws_api_gateway_method.go diff --git a/pkg/resource/aws/aws_api_gateway_method_response.go b/enumeration/resource/aws/aws_api_gateway_method_response.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_method_response.go rename to enumeration/resource/aws/aws_api_gateway_method_response.go diff --git a/pkg/resource/aws/aws_api_gateway_method_settings.go b/enumeration/resource/aws/aws_api_gateway_method_settings.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_method_settings.go rename to enumeration/resource/aws/aws_api_gateway_method_settings.go diff --git a/pkg/resource/aws/aws_api_gateway_model.go b/enumeration/resource/aws/aws_api_gateway_model.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_model.go rename to enumeration/resource/aws/aws_api_gateway_model.go diff --git a/pkg/resource/aws/aws_api_gateway_request_validator.go b/enumeration/resource/aws/aws_api_gateway_request_validator.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_request_validator.go rename to enumeration/resource/aws/aws_api_gateway_request_validator.go diff --git a/pkg/resource/aws/aws_api_gateway_resource.go b/enumeration/resource/aws/aws_api_gateway_resource.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_resource.go rename to enumeration/resource/aws/aws_api_gateway_resource.go diff --git a/pkg/resource/aws/aws_api_gateway_rest_api.go b/enumeration/resource/aws/aws_api_gateway_rest_api.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_rest_api.go rename to enumeration/resource/aws/aws_api_gateway_rest_api.go diff --git a/pkg/resource/aws/aws_api_gateway_rest_api_policy.go b/enumeration/resource/aws/aws_api_gateway_rest_api_policy.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_rest_api_policy.go rename to enumeration/resource/aws/aws_api_gateway_rest_api_policy.go diff --git a/pkg/resource/aws/aws_api_gateway_stage.go b/enumeration/resource/aws/aws_api_gateway_stage.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_stage.go rename to enumeration/resource/aws/aws_api_gateway_stage.go diff --git a/pkg/resource/aws/aws_api_gateway_vpc_link.go b/enumeration/resource/aws/aws_api_gateway_vpc_link.go similarity index 100% rename from pkg/resource/aws/aws_api_gateway_vpc_link.go rename to enumeration/resource/aws/aws_api_gateway_vpc_link.go diff --git a/pkg/resource/aws/aws_apigatewayv2_api.go b/enumeration/resource/aws/aws_apigatewayv2_api.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_api.go rename to enumeration/resource/aws/aws_apigatewayv2_api.go diff --git a/pkg/resource/aws/aws_apigatewayv2_authorizer.go b/enumeration/resource/aws/aws_apigatewayv2_authorizer.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_authorizer.go rename to enumeration/resource/aws/aws_apigatewayv2_authorizer.go diff --git a/pkg/resource/aws/aws_apigatewayv2_deployment.go b/enumeration/resource/aws/aws_apigatewayv2_deployment.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_deployment.go rename to enumeration/resource/aws/aws_apigatewayv2_deployment.go diff --git a/pkg/resource/aws/aws_apigatewayv2_domain_name.go b/enumeration/resource/aws/aws_apigatewayv2_domain_name.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_domain_name.go rename to enumeration/resource/aws/aws_apigatewayv2_domain_name.go diff --git a/pkg/resource/aws/aws_apigatewayv2_integration.go b/enumeration/resource/aws/aws_apigatewayv2_integration.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_integration.go rename to enumeration/resource/aws/aws_apigatewayv2_integration.go diff --git a/pkg/resource/aws/aws_apigatewayv2_integration_response.go b/enumeration/resource/aws/aws_apigatewayv2_integration_response.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_integration_response.go rename to enumeration/resource/aws/aws_apigatewayv2_integration_response.go diff --git a/enumeration/resource/aws/aws_apigatewayv2_mapping.go b/enumeration/resource/aws/aws_apigatewayv2_mapping.go new file mode 100644 index 00000000..dd563a9a --- /dev/null +++ b/enumeration/resource/aws/aws_apigatewayv2_mapping.go @@ -0,0 +1,23 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsApiGatewayV2MappingResourceType = "aws_apigatewayv2_api_mapping" + +func initAwsApiGatewayV2MappingMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc( + AwsApiGatewayV2MappingResourceType, + func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("api_id"); v != nil { + attrs["Api"] = *v + } + if v := res.Attributes().GetString("stage"); v != nil { + attrs["Stage"] = *v + } + + return attrs + }, + ) +} diff --git a/enumeration/resource/aws/aws_apigatewayv2_model.go b/enumeration/resource/aws/aws_apigatewayv2_model.go new file mode 100644 index 00000000..471ae493 --- /dev/null +++ b/enumeration/resource/aws/aws_apigatewayv2_model.go @@ -0,0 +1,16 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsApiGatewayV2ModelResourceType = "aws_apigatewayv2_model" + +func initAwsApiGatewayV2ModelMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc( + AwsApiGatewayV2ModelResourceType, + func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attributes().GetString("name"), + } + }, + ) +} diff --git a/pkg/resource/aws/aws_apigatewayv2_route.go b/enumeration/resource/aws/aws_apigatewayv2_route.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_route.go rename to enumeration/resource/aws/aws_apigatewayv2_route.go diff --git a/pkg/resource/aws/aws_apigatewayv2_route_response.go b/enumeration/resource/aws/aws_apigatewayv2_route_response.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_route_response.go rename to enumeration/resource/aws/aws_apigatewayv2_route_response.go diff --git a/pkg/resource/aws/aws_apigatewayv2_stage.go b/enumeration/resource/aws/aws_apigatewayv2_stage.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_stage.go rename to enumeration/resource/aws/aws_apigatewayv2_stage.go diff --git a/pkg/resource/aws/aws_apigatewayv2_vpc_link.go b/enumeration/resource/aws/aws_apigatewayv2_vpc_link.go similarity index 100% rename from pkg/resource/aws/aws_apigatewayv2_vpc_link.go rename to enumeration/resource/aws/aws_apigatewayv2_vpc_link.go diff --git a/enumeration/resource/aws/aws_appautoscaling_policy.go b/enumeration/resource/aws/aws_appautoscaling_policy.go new file mode 100644 index 00000000..b212828a --- /dev/null +++ b/enumeration/resource/aws/aws_appautoscaling_policy.go @@ -0,0 +1,24 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsAppAutoscalingPolicyResourceType = "aws_appautoscaling_policy" + +func initAwsAppAutoscalingPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsAppAutoscalingPolicyResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attributes().GetString("name"), + "resource_id": *res.Attributes().GetString("resource_id"), + "service_namespace": *res.Attributes().GetString("service_namespace"), + "scalable_dimension": *res.Attributes().GetString("scalable_dimension"), + } + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsAppAutoscalingPolicyResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + if v := res.Attributes().GetString("scalable_dimension"); v != nil && *v != "" { + attrs["Scalable dimension"] = *v + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsAppAutoscalingPolicyResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_appautoscaling_scheduled_action.go b/enumeration/resource/aws/aws_appautoscaling_scheduled_action.go similarity index 100% rename from pkg/resource/aws/aws_appautoscaling_scheduled_action.go rename to enumeration/resource/aws/aws_appautoscaling_scheduled_action.go diff --git a/enumeration/resource/aws/aws_appautoscaling_target.go b/enumeration/resource/aws/aws_appautoscaling_target.go new file mode 100644 index 00000000..bd01002a --- /dev/null +++ b/enumeration/resource/aws/aws_appautoscaling_target.go @@ -0,0 +1,22 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsAppAutoscalingTargetResourceType = "aws_appautoscaling_target" + +func initAwsAppAutoscalingTargetMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsAppAutoscalingTargetResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "service_namespace": *res.Attributes().GetString("service_namespace"), + "scalable_dimension": *res.Attributes().GetString("scalable_dimension"), + } + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsAppAutoscalingTargetResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + if v := res.Attributes().GetString("scalable_dimension"); v != nil && *v != "" { + attrs["Scalable dimension"] = *v + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsAppAutoscalingTargetResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_cloudformation_stack.go b/enumeration/resource/aws/aws_cloudformation_stack.go new file mode 100644 index 00000000..27531968 --- /dev/null +++ b/enumeration/resource/aws/aws_cloudformation_stack.go @@ -0,0 +1,22 @@ +package aws + +import ( + "strconv" + + "github.com/hashicorp/terraform/flatmap" + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsCloudformationStackResourceType = "aws_cloudformation_stack" + +func initAwsCloudformationStackMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsCloudformationStackResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]interface{}) + if v := res.Attributes().GetMap("parameters"); v != nil { + attrs["parameters.%"] = strconv.FormatInt(int64(len(v)), 10) + attrs["parameters"] = v + } + return flatmap.Flatten(attrs) + }) + resourceSchemaRepository.SetFlags(AwsCloudformationStackResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_cloudfront_distribution.go b/enumeration/resource/aws/aws_cloudfront_distribution.go new file mode 100644 index 00000000..8d51bbe6 --- /dev/null +++ b/enumeration/resource/aws/aws_cloudfront_distribution.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsCloudfrontDistributionResourceType = "aws_cloudfront_distribution" + +func initAwsCloudfrontDistributionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsCloudfrontDistributionResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_db_instance.go b/enumeration/resource/aws/aws_db_instance.go new file mode 100644 index 00000000..2632e554 --- /dev/null +++ b/enumeration/resource/aws/aws_db_instance.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsDbInstanceResourceType = "aws_db_instance" + +func initAwsDbInstanceMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsDbInstanceResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_db_subnet_group.go b/enumeration/resource/aws/aws_db_subnet_group.go new file mode 100644 index 00000000..faa474c7 --- /dev/null +++ b/enumeration/resource/aws/aws_db_subnet_group.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsDbSubnetGroupResourceType = "aws_db_subnet_group" + +func initAwsDbSubnetGroupMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsDbSubnetGroupResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_default_network_acl.go b/enumeration/resource/aws/aws_default_network_acl.go new file mode 100644 index 00000000..f15deb28 --- /dev/null +++ b/enumeration/resource/aws/aws_default_network_acl.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsDefaultNetworkACLResourceType = "aws_default_network_acl" + +func initAwsDefaultNetworkACLMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsDefaultNetworkACLResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_default_route_table.go b/enumeration/resource/aws/aws_default_route_table.go new file mode 100644 index 00000000..2386d346 --- /dev/null +++ b/enumeration/resource/aws/aws_default_route_table.go @@ -0,0 +1,18 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsDefaultRouteTableResourceType = "aws_default_route_table" + +func initAwsDefaultRouteTableMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsDefaultRouteTableResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "vpc_id": *res.Attributes().GetString("vpc_id"), + } + }) + resourceSchemaRepository.SetFlags(AwsDefaultRouteTableResourceType, resource.FlagDeepMode) + resourceSchemaRepository.SetNormalizeFunc(AwsDefaultRouteTableResourceType, func(res *resource.Resource) { + val := res.Attrs + val.SafeDelete([]string{"timeouts"}) + }) +} diff --git a/enumeration/resource/aws/aws_default_security_group.go b/enumeration/resource/aws/aws_default_security_group.go new file mode 100644 index 00000000..000b24e1 --- /dev/null +++ b/enumeration/resource/aws/aws_default_security_group.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsDefaultSecurityGroupResourceType = "aws_default_security_group" + +func initAwsDefaultSecurityGroupMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsDefaultSecurityGroupResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_default_subnet.go b/enumeration/resource/aws/aws_default_subnet.go new file mode 100644 index 00000000..74702a1e --- /dev/null +++ b/enumeration/resource/aws/aws_default_subnet.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsDefaultSubnetResourceType = "aws_default_subnet" + +func initAwsDefaultSubnetMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsDefaultSubnetResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_default_vpc.go b/enumeration/resource/aws/aws_default_vpc.go new file mode 100644 index 00000000..62a02a12 --- /dev/null +++ b/enumeration/resource/aws/aws_default_vpc.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsDefaultVpcResourceType = "aws_default_vpc" + +func initAwsDefaultVpcMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsDefaultVpcResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_dynamodb_table.go b/enumeration/resource/aws/aws_dynamodb_table.go new file mode 100644 index 00000000..11bad21e --- /dev/null +++ b/enumeration/resource/aws/aws_dynamodb_table.go @@ -0,0 +1,17 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsDynamodbTableResourceType = "aws_dynamodb_table" + +func initAwsDynamodbTableMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsDynamodbTableResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "table_name": res.ResourceId(), + } + }) + resourceSchemaRepository.SetFlags(AwsDynamodbTableResourceType, resource.FlagDeepMode) + +} diff --git a/enumeration/resource/aws/aws_ebs_encryption_by_default.go b/enumeration/resource/aws/aws_ebs_encryption_by_default.go new file mode 100644 index 00000000..502b99f3 --- /dev/null +++ b/enumeration/resource/aws/aws_ebs_encryption_by_default.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsEbsEncryptionByDefaultResourceType = "aws_ebs_encryption_by_default" + +func initAwsEbsEncryptionByDefaultMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsEbsEncryptionByDefaultResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_ebs_snapshot.go b/enumeration/resource/aws/aws_ebs_snapshot.go new file mode 100644 index 00000000..e18e2cb6 --- /dev/null +++ b/enumeration/resource/aws/aws_ebs_snapshot.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsEbsSnapshotResourceType = "aws_ebs_snapshot" + +func initAwsEbsSnapshotMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsEbsSnapshotResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_ebs_volume.go b/enumeration/resource/aws/aws_ebs_volume.go new file mode 100644 index 00000000..d362ba4e --- /dev/null +++ b/enumeration/resource/aws/aws_ebs_volume.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsEbsVolumeResourceType = "aws_ebs_volume" + +func initAwsEbsVolumeMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsEbsVolumeResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_ecr_repository.go b/enumeration/resource/aws/aws_ecr_repository.go new file mode 100644 index 00000000..4f1d1ccd --- /dev/null +++ b/enumeration/resource/aws/aws_ecr_repository.go @@ -0,0 +1,10 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsEcrRepositoryResourceType = "aws_ecr_repository" + +func initAwsEcrRepositoryMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + + resourceSchemaRepository.SetFlags(AwsEcrRepositoryResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_ecr_repository_policy.go b/enumeration/resource/aws/aws_ecr_repository_policy.go similarity index 100% rename from pkg/resource/aws/aws_ecr_repository_policy.go rename to enumeration/resource/aws/aws_ecr_repository_policy.go diff --git a/enumeration/resource/aws/aws_eip.go b/enumeration/resource/aws/aws_eip.go new file mode 100644 index 00000000..e1d914c4 --- /dev/null +++ b/enumeration/resource/aws/aws_eip.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsEipResourceType = "aws_eip" + +func initAwsEipMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsEipResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_eip_association.go b/enumeration/resource/aws/aws_eip_association.go new file mode 100644 index 00000000..60248114 --- /dev/null +++ b/enumeration/resource/aws/aws_eip_association.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsEipAssociationResourceType = "aws_eip_association" + +func initAwsEipAssociationMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsEipAssociationResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_elasticache_cluster.go b/enumeration/resource/aws/aws_elasticache_cluster.go similarity index 100% rename from pkg/resource/aws/aws_elasticache_cluster.go rename to enumeration/resource/aws/aws_elasticache_cluster.go diff --git a/pkg/resource/aws/aws_elb.go b/enumeration/resource/aws/aws_elb.go similarity index 100% rename from pkg/resource/aws/aws_elb.go rename to enumeration/resource/aws/aws_elb.go diff --git a/enumeration/resource/aws/aws_iam_access_key.go b/enumeration/resource/aws/aws_iam_access_key.go new file mode 100644 index 00000000..3706284e --- /dev/null +++ b/enumeration/resource/aws/aws_iam_access_key.go @@ -0,0 +1,26 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamAccessKeyResourceType = "aws_iam_access_key" + +func initAwsIAMAccessKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsIamAccessKeyResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "user": *res.Attributes().GetString("user"), + } + }) + + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsIamAccessKeyResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if user := val.GetString("user"); user != nil && *user != "" { + attrs["User"] = *user + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsIamAccessKeyResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_iam_group.go b/enumeration/resource/aws/aws_iam_group.go similarity index 100% rename from pkg/resource/aws/aws_iam_group.go rename to enumeration/resource/aws/aws_iam_group.go diff --git a/pkg/resource/aws/aws_iam_group_policy.go b/enumeration/resource/aws/aws_iam_group_policy.go similarity index 100% rename from pkg/resource/aws/aws_iam_group_policy.go rename to enumeration/resource/aws/aws_iam_group_policy.go diff --git a/pkg/resource/aws/aws_iam_group_policy_attachment.go b/enumeration/resource/aws/aws_iam_group_policy_attachment.go similarity index 100% rename from pkg/resource/aws/aws_iam_group_policy_attachment.go rename to enumeration/resource/aws/aws_iam_group_policy_attachment.go diff --git a/enumeration/resource/aws/aws_iam_policy.go b/enumeration/resource/aws/aws_iam_policy.go new file mode 100644 index 00000000..538411ac --- /dev/null +++ b/enumeration/resource/aws/aws_iam_policy.go @@ -0,0 +1,16 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamPolicyResourceType = "aws_iam_policy" + +func initAwsIAMPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.UpdateSchema(AwsIamPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsIamPolicyResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_policy_attachment.go b/enumeration/resource/aws/aws_iam_policy_attachment.go new file mode 100644 index 00000000..088f0b14 --- /dev/null +++ b/enumeration/resource/aws/aws_iam_policy_attachment.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamPolicyAttachmentResourceType = "aws_iam_policy_attachment" + +func initAwsIAMPolicyAttachmentMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsIamPolicyAttachmentResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_role.go b/enumeration/resource/aws/aws_iam_role.go new file mode 100644 index 00000000..e736f29c --- /dev/null +++ b/enumeration/resource/aws/aws_iam_role.go @@ -0,0 +1,16 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamRoleResourceType = "aws_iam_role" + +func initAwsIAMRoleMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.UpdateSchema(AwsIamRoleResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "assume_role_policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsIamRoleResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_role_policy.go b/enumeration/resource/aws/aws_iam_role_policy.go new file mode 100644 index 00000000..d5ed4580 --- /dev/null +++ b/enumeration/resource/aws/aws_iam_role_policy.go @@ -0,0 +1,16 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamRolePolicyResourceType = "aws_iam_role_policy" + +func initAwsIAMRolePolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.UpdateSchema(AwsIamRolePolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsIamRolePolicyResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_role_policy_attachment.go b/enumeration/resource/aws/aws_iam_role_policy_attachment.go new file mode 100644 index 00000000..763e1664 --- /dev/null +++ b/enumeration/resource/aws/aws_iam_role_policy_attachment.go @@ -0,0 +1,15 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsIamRolePolicyAttachmentResourceType = "aws_iam_role_policy_attachment" + +func initAwsIamRolePolicyAttachmentMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsIamRolePolicyAttachmentResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "role": *res.Attributes().GetString("role"), + "policy_arn": *res.Attributes().GetString("policy_arn"), + } + }) + resourceSchemaRepository.SetFlags(AwsIamRolePolicyAttachmentResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_user.go b/enumeration/resource/aws/aws_iam_user.go new file mode 100644 index 00000000..e79413be --- /dev/null +++ b/enumeration/resource/aws/aws_iam_user.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamUserResourceType = "aws_iam_user" + +func initAwsIAMUserMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsIamUserResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_user_policy.go b/enumeration/resource/aws/aws_iam_user_policy.go new file mode 100644 index 00000000..3de42a38 --- /dev/null +++ b/enumeration/resource/aws/aws_iam_user_policy.go @@ -0,0 +1,16 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsIamUserPolicyResourceType = "aws_iam_user_policy" + +func initAwsIAMUserPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.UpdateSchema(AwsIamUserPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsIamUserPolicyResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_iam_user_policy_attachment.go b/enumeration/resource/aws/aws_iam_user_policy_attachment.go new file mode 100644 index 00000000..06f1bef3 --- /dev/null +++ b/enumeration/resource/aws/aws_iam_user_policy_attachment.go @@ -0,0 +1,15 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsIamUserPolicyAttachmentResourceType = "aws_iam_user_policy_attachment" + +func initAwsIamUserPolicyAttachmentMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsIamUserPolicyAttachmentResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "user": *res.Attributes().GetString("user"), + "policy_arn": *res.Attributes().GetString("policy_arn"), + } + }) + resourceSchemaRepository.SetFlags(AwsIamUserPolicyAttachmentResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_instance.go b/enumeration/resource/aws/aws_instance.go new file mode 100644 index 00000000..98f12250 --- /dev/null +++ b/enumeration/resource/aws/aws_instance.go @@ -0,0 +1,21 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsInstanceResourceType = "aws_instance" + +func initAwsInstanceMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsInstanceResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if tags := val.GetMap("tags"); tags != nil { + if name, ok := tags["Name"]; ok { + attrs["Name"] = name.(string) + } + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsInstanceResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_internet_gateway.go b/enumeration/resource/aws/aws_internet_gateway.go new file mode 100644 index 00000000..4c6c6966 --- /dev/null +++ b/enumeration/resource/aws/aws_internet_gateway.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsInternetGatewayResourceType = "aws_internet_gateway" + +func initAwsInternetGatewayMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsInternetGatewayResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_key_pair.go b/enumeration/resource/aws/aws_key_pair.go new file mode 100644 index 00000000..35514570 --- /dev/null +++ b/enumeration/resource/aws/aws_key_pair.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsKeyPairResourceType = "aws_key_pair" + +func initAwsKeyPairMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsKeyPairResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_kms_alias.go b/enumeration/resource/aws/aws_kms_alias.go new file mode 100644 index 00000000..ff6750f1 --- /dev/null +++ b/enumeration/resource/aws/aws_kms_alias.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsKmsAliasResourceType = "aws_kms_alias" + +func initAwsKmsAliasMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsKmsAliasResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_kms_key.go b/enumeration/resource/aws/aws_kms_key.go new file mode 100644 index 00000000..9e2658ae --- /dev/null +++ b/enumeration/resource/aws/aws_kms_key.go @@ -0,0 +1,16 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsKmsKeyResourceType = "aws_kms_key" + +func initAwsKmsKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.UpdateSchema(AwsKmsKeyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsKmsKeyResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_lambda_event_source_mapping.go b/enumeration/resource/aws/aws_lambda_event_source_mapping.go new file mode 100644 index 00000000..25e66a7b --- /dev/null +++ b/enumeration/resource/aws/aws_lambda_event_source_mapping.go @@ -0,0 +1,22 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsLambdaEventSourceMappingResourceType = "aws_lambda_event_source_mapping" + +func initAwsLambdaEventSourceMappingMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsLambdaEventSourceMappingResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + source := val.GetString("event_source_arn") + dest := val.GetString("function_name") + if source != nil && *source != "" && dest != nil && *dest != "" { + attrs["Source"] = *source + attrs["Dest"] = *dest + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsLambdaEventSourceMappingResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_lambda_function.go b/enumeration/resource/aws/aws_lambda_function.go new file mode 100644 index 00000000..1a16b4f4 --- /dev/null +++ b/enumeration/resource/aws/aws_lambda_function.go @@ -0,0 +1,17 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsLambdaFunctionResourceType = "aws_lambda_function" + +func initAwsLambdaFunctionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsLambdaFunctionResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "function_name": res.ResourceId(), + } + }) + resourceSchemaRepository.SetFlags(AwsLambdaFunctionResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_launch_configuration.go b/enumeration/resource/aws/aws_launch_configuration.go similarity index 100% rename from pkg/resource/aws/aws_launch_configuration.go rename to enumeration/resource/aws/aws_launch_configuration.go diff --git a/enumeration/resource/aws/aws_launch_template.go b/enumeration/resource/aws/aws_launch_template.go new file mode 100644 index 00000000..ce931e20 --- /dev/null +++ b/enumeration/resource/aws/aws_launch_template.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsLaunchTemplateResourceType = "aws_launch_template" + +func initAwsLaunchTemplateMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsLaunchTemplateResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_lb.go b/enumeration/resource/aws/aws_lb.go new file mode 100644 index 00000000..791f8068 --- /dev/null +++ b/enumeration/resource/aws/aws_lb.go @@ -0,0 +1,13 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsLoadBalancerResourceType = "aws_lb" + +func initAwsLoadBalancerMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsLoadBalancerResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "Name": *res.Attributes().GetString("name"), + } + }) +} diff --git a/pkg/resource/aws/aws_lb_listener.go b/enumeration/resource/aws/aws_lb_listener.go similarity index 100% rename from pkg/resource/aws/aws_lb_listener.go rename to enumeration/resource/aws/aws_lb_listener.go diff --git a/enumeration/resource/aws/aws_nat_gateway.go b/enumeration/resource/aws/aws_nat_gateway.go new file mode 100644 index 00000000..c51f3a98 --- /dev/null +++ b/enumeration/resource/aws/aws_nat_gateway.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsNatGatewayResourceType = "aws_nat_gateway" + +func initNatGatewayMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsNatGatewayResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_network_acl.go b/enumeration/resource/aws/aws_network_acl.go new file mode 100644 index 00000000..46973ef1 --- /dev/null +++ b/enumeration/resource/aws/aws_network_acl.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsNetworkACLResourceType = "aws_network_acl" + +func initAwsNetworkACLMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsNetworkACLResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_network_acl_rule.go b/enumeration/resource/aws/aws_network_acl_rule.go new file mode 100644 index 00000000..d646937b --- /dev/null +++ b/enumeration/resource/aws/aws_network_acl_rule.go @@ -0,0 +1,69 @@ +package aws + +import ( + "bytes" + "fmt" + "strconv" + + "github.com/hashicorp/terraform/helper/hashcode" + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsNetworkACLRuleResourceType = "aws_network_acl_rule" + +func initAwsNetworkACLRuleMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsNetworkACLRuleResourceType, resource.FlagDeepMode) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsNetworkACLRuleResourceType, func(res *resource.Resource) map[string]string { + + ruleNumber := strconv.FormatInt(int64(*res.Attrs.GetFloat64("rule_number")), 10) + if ruleNumber == "32767" { + ruleNumber = "*" + } + + attrs := map[string]string{ + "Network": *res.Attrs.GetString("network_acl_id"), + "Egress": strconv.FormatBool(*res.Attrs.GetBool("egress")), + "Rule number": ruleNumber, + } + + if proto := res.Attrs.GetString("protocol"); proto != nil { + if *proto == "-1" { + *proto = "All" + } + attrs["Protocol"] = *proto + } + + if res.Attrs.GetFloat64("from_port") != nil && res.Attrs.GetFloat64("to_port") != nil { + attrs["Port range"] = fmt.Sprintf("%d - %d", + int64(*res.Attrs.GetFloat64("from_port")), + int64(*res.Attrs.GetFloat64("to_port")), + ) + } + + if cidr := res.Attrs.GetString("cidr_block"); cidr != nil && *cidr != "" { + attrs["CIDR"] = *cidr + } + + if cidr := res.Attrs.GetString("ipv6_cidr_block"); cidr != nil && *cidr != "" { + attrs["CIDR"] = *cidr + } + + return attrs + }) + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsNetworkACLRuleResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "network_acl_id": *res.Attrs.GetString("network_acl_id"), + "rule_number": strconv.FormatInt(int64(*res.Attrs.GetFloat64("rule_number")), 10), + "egress": strconv.FormatBool(*res.Attrs.GetBool("egress")), + } + }) +} + +func CreateNetworkACLRuleID(networkAclId string, ruleNumber int, egress bool, protocol string) string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%s-", networkAclId)) + buf.WriteString(fmt.Sprintf("%d-", ruleNumber)) + buf.WriteString(fmt.Sprintf("%t-", egress)) + buf.WriteString(fmt.Sprintf("%s-", protocol)) + return fmt.Sprintf("nacl-%d", hashcode.String(buf.String())) +} diff --git a/enumeration/resource/aws/aws_rds_cluster.go b/enumeration/resource/aws/aws_rds_cluster.go new file mode 100644 index 00000000..56378351 --- /dev/null +++ b/enumeration/resource/aws/aws_rds_cluster.go @@ -0,0 +1,17 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsRDSClusterResourceType = "aws_rds_cluster" + +func initAwsRDSClusterMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsRDSClusterResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "cluster_identifier": *res.Attributes().GetString("cluster_identifier"), + "database_name": *res.Attributes().GetString("database_name"), + } + }) + resourceSchemaRepository.SetFlags(AwsRDSClusterResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_rds_cluster_instance.go b/enumeration/resource/aws/aws_rds_cluster_instance.go similarity index 100% rename from pkg/resource/aws/aws_rds_cluster_instance.go rename to enumeration/resource/aws/aws_rds_cluster_instance.go diff --git a/enumeration/resource/aws/aws_route.go b/enumeration/resource/aws/aws_route.go new file mode 100644 index 00000000..65b04566 --- /dev/null +++ b/enumeration/resource/aws/aws_route.go @@ -0,0 +1,62 @@ +package aws + +import ( + "fmt" + + "github.com/hashicorp/terraform/helper/hashcode" + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsRouteResourceType = "aws_route" + +func initAwsRouteMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsRouteResourceType, func(res *resource.Resource) map[string]string { + attributes := map[string]string{ + "route_table_id": *res.Attributes().GetString("route_table_id"), + } + if ipv4 := res.Attributes().GetString("destination_cidr_block"); ipv4 != nil && *ipv4 != "" { + attributes["destination_cidr_block"] = *ipv4 + } + if ipv6 := res.Attributes().GetString("destination_ipv6_cidr_block"); ipv6 != nil && *ipv6 != "" { + attributes["destination_ipv6_cidr_block"] = *ipv6 + } + if prefixes := res.Attributes().GetString("destination_prefix_list_id"); prefixes != nil && *prefixes != "" { + attributes["destination_prefix_list_id"] = *prefixes + } + return attributes + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRouteResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if rtID := val.GetString("route_table_id"); rtID != nil && *rtID != "" { + attrs["Table"] = *rtID + } + if ipv4 := val.GetString("destination_cidr_block"); ipv4 != nil && *ipv4 != "" { + attrs["Destination"] = *ipv4 + } + if ipv6 := val.GetString("destination_ipv6_cidr_block"); ipv6 != nil && *ipv6 != "" { + attrs["Destination"] = *ipv6 + } + if prefix := val.GetString("destination_prefix_list_id"); prefix != nil && *prefix != "" { + attrs["Destination"] = *prefix + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsRouteResourceType, resource.FlagDeepMode) +} + +func CalculateRouteID(tableId, CidrBlock, Ipv6CidrBlock, PrefixListId *string) string { + if CidrBlock != nil && *CidrBlock != "" { + return fmt.Sprintf("r-%s%d", *tableId, hashcode.String(*CidrBlock)) + } + + if Ipv6CidrBlock != nil && *Ipv6CidrBlock != "" { + return fmt.Sprintf("r-%s%d", *tableId, hashcode.String(*Ipv6CidrBlock)) + } + + if PrefixListId != nil && *PrefixListId != "" { + return fmt.Sprintf("r-%s%d", *tableId, hashcode.String(*PrefixListId)) + } + + return "" +} diff --git a/enumeration/resource/aws/aws_route53_health_check.go b/enumeration/resource/aws/aws_route53_health_check.go new file mode 100644 index 00000000..e87e5b37 --- /dev/null +++ b/enumeration/resource/aws/aws_route53_health_check.go @@ -0,0 +1,43 @@ +package aws + +import ( + "fmt" + + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsRoute53HealthCheckResourceType = "aws_route53_health_check" + +func initAwsRoute53HealthCheckMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRoute53HealthCheckResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if tags := val.GetMap("tags"); tags != nil { + if name, ok := tags["Name"]; ok { + attrs["Name"] = name.(string) + } + } + port := val.GetInt("port") + path := val.GetString("resource_path") + if fqdn := val.GetString("fqdn"); fqdn != nil && *fqdn != "" { + attrs["Fqdn"] = *fqdn + if port != nil { + attrs["Port"] = fmt.Sprintf("%d", *port) + } + if path != nil && *path != "" { + attrs["Path"] = *path + } + } + if address := val.GetString("ip_address"); address != nil && *address != "" { + attrs["IpAddress"] = *address + if port != nil { + attrs["Port"] = fmt.Sprintf("%d", *port) + } + if path != nil && *path != "" { + attrs["Path"] = *path + } + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsRoute53HealthCheckResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_route53_record.go b/enumeration/resource/aws/aws_route53_record.go new file mode 100644 index 00000000..de67f8c4 --- /dev/null +++ b/enumeration/resource/aws/aws_route53_record.go @@ -0,0 +1,25 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsRoute53RecordResourceType = "aws_route53_record" + +func initAwsRoute53RecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRoute53RecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if fqdn := val.GetString("fqdn"); fqdn != nil && *fqdn != "" { + attrs["Fqdn"] = *fqdn + } + if ty := val.GetString("type"); ty != nil && *ty != "" { + attrs["Type"] = *ty + } + if zoneID := val.GetString("zone_id"); zoneID != nil && *zoneID != "" { + attrs["ZoneId"] = *zoneID + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsRoute53RecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_route53_zone.go b/enumeration/resource/aws/aws_route53_zone.go new file mode 100644 index 00000000..4d188b05 --- /dev/null +++ b/enumeration/resource/aws/aws_route53_zone.go @@ -0,0 +1,19 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsRoute53ZoneResourceType = "aws_route53_zone" + +func initAwsRoute53ZoneMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRoute53ZoneResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsRoute53ZoneResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_route_table.go b/enumeration/resource/aws/aws_route_table.go new file mode 100644 index 00000000..edef3856 --- /dev/null +++ b/enumeration/resource/aws/aws_route_table.go @@ -0,0 +1,13 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsRouteTableResourceType = "aws_route_table" + +func initAwsRouteTableMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsRouteTableResourceType, resource.FlagDeepMode) + resourceSchemaRepository.SetNormalizeFunc(AwsRouteTableResourceType, func(res *resource.Resource) { + val := res.Attrs + val.SafeDelete([]string{"timeouts"}) + }) +} diff --git a/enumeration/resource/aws/aws_route_table_association.go b/enumeration/resource/aws/aws_route_table_association.go new file mode 100644 index 00000000..75ae4c5c --- /dev/null +++ b/enumeration/resource/aws/aws_route_table_association.go @@ -0,0 +1,30 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsRouteTableAssociationResourceType = "aws_route_table_association" + +func initAwsRouteTableAssociationMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsRouteTableAssociationResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "route_table_id": *res.Attributes().GetString("route_table_id"), + } + }) + + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRouteTableAssociationResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if rtID := val.GetString("route_table_id"); rtID != nil && *rtID != "" { + attrs["Table"] = *rtID + } + if gtwID := val.GetString("gateway_id"); gtwID != nil && *gtwID != "" { + attrs["Gateway"] = *gtwID + } + if subnetID := val.GetString("subnet_id"); subnetID != nil && *subnetID != "" { + attrs["Subnet"] = *subnetID + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsRouteTableAssociationResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_s3_bucket.go b/enumeration/resource/aws/aws_s3_bucket.go new file mode 100644 index 00000000..81c12b04 --- /dev/null +++ b/enumeration/resource/aws/aws_s3_bucket.go @@ -0,0 +1,21 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsS3BucketResourceType = "aws_s3_bucket" + +func initAwsS3BucketMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "alias": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.UpdateSchema(AwsS3BucketResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsS3BucketResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_s3_bucket_analytics_configuration.go b/enumeration/resource/aws/aws_s3_bucket_analytics_configuration.go new file mode 100644 index 00000000..5f894cf4 --- /dev/null +++ b/enumeration/resource/aws/aws_s3_bucket_analytics_configuration.go @@ -0,0 +1,14 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsS3BucketAnalyticsConfigurationResourceType = "aws_s3_bucket_analytics_configuration" + +func initAwsS3BucketAnalyticsConfigurationMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketAnalyticsConfigurationResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "alias": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.SetFlags(AwsS3BucketAnalyticsConfigurationResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_s3_bucket_inventory.go b/enumeration/resource/aws/aws_s3_bucket_inventory.go new file mode 100644 index 00000000..6e07eaad --- /dev/null +++ b/enumeration/resource/aws/aws_s3_bucket_inventory.go @@ -0,0 +1,14 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsS3BucketInventoryResourceType = "aws_s3_bucket_inventory" + +func initAwsS3BucketInventoryMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketInventoryResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "alias": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.SetFlags(AwsS3BucketInventoryResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_s3_bucket_metric.go b/enumeration/resource/aws/aws_s3_bucket_metric.go new file mode 100644 index 00000000..29b238d6 --- /dev/null +++ b/enumeration/resource/aws/aws_s3_bucket_metric.go @@ -0,0 +1,14 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsS3BucketMetricResourceType = "aws_s3_bucket_metric" + +func initAwsS3BucketMetricMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketMetricResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "alias": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.SetFlags(AwsS3BucketMetricResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_s3_bucket_notification.go b/enumeration/resource/aws/aws_s3_bucket_notification.go new file mode 100644 index 00000000..1f425e1f --- /dev/null +++ b/enumeration/resource/aws/aws_s3_bucket_notification.go @@ -0,0 +1,14 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsS3BucketNotificationResourceType = "aws_s3_bucket_notification" + +func initAwsS3BucketNotificationMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketNotificationResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "alias": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.SetFlags(AwsS3BucketNotificationResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_s3_bucket_policy.go b/enumeration/resource/aws/aws_s3_bucket_policy.go new file mode 100644 index 00000000..35ec7754 --- /dev/null +++ b/enumeration/resource/aws/aws_s3_bucket_policy.go @@ -0,0 +1,21 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsS3BucketPolicyResourceType = "aws_s3_bucket_policy" + +func initAwsS3BucketPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketPolicyResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "alias": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.UpdateSchema(AwsS3BucketPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsS3BucketPolicyResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/aws/aws_s3_bucket_public_access_block.go b/enumeration/resource/aws/aws_s3_bucket_public_access_block.go similarity index 100% rename from pkg/resource/aws/aws_s3_bucket_public_access_block.go rename to enumeration/resource/aws/aws_s3_bucket_public_access_block.go diff --git a/enumeration/resource/aws/aws_security_group.go b/enumeration/resource/aws/aws_security_group.go new file mode 100644 index 00000000..55510c70 --- /dev/null +++ b/enumeration/resource/aws/aws_security_group.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsSecurityGroupResourceType = "aws_security_group" + +func initAwsSecurityGroupMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsSecurityGroupResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_security_group_rule.go b/enumeration/resource/aws/aws_security_group_rule.go new file mode 100644 index 00000000..48436ac6 --- /dev/null +++ b/enumeration/resource/aws/aws_security_group_rule.go @@ -0,0 +1,174 @@ +package aws + +import ( + "bytes" + "fmt" + "strings" + + "github.com/hashicorp/terraform/flatmap" + "github.com/hashicorp/terraform/helper/hashcode" + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsSecurityGroupRuleResourceType = "aws_security_group_rule" + +func CreateSecurityGroupRuleIdHash(attrs *resource.Attributes) string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("security_group_id"))) + if attrs.GetInt("from_port") != nil && *attrs.GetInt("from_port") > 0 { + buf.WriteString(fmt.Sprintf("%d-", *attrs.GetInt("from_port"))) + } + if attrs.GetInt("to_port") != nil && *attrs.GetInt("to_port") > 0 { + buf.WriteString(fmt.Sprintf("%d-", *attrs.GetInt("to_port"))) + } + buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("protocol"))) + buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("type"))) + + if attrs.GetSlice("cidr_blocks") != nil { + for _, v := range attrs.GetSlice("cidr_blocks") { + buf.WriteString(fmt.Sprintf("%s-", v)) + } + } + + if attrs.GetSlice("ipv6_cidr_blocks") != nil { + for _, v := range attrs.GetSlice("ipv6_cidr_blocks") { + buf.WriteString(fmt.Sprintf("%s-", v)) + } + } + + if attrs.GetSlice("prefix_list_ids") != nil { + for _, v := range attrs.GetSlice("prefix_list_ids") { + buf.WriteString(fmt.Sprintf("%s-", v)) + } + } + + if (attrs.GetBool("self") != nil && *attrs.GetBool("self")) || + (attrs.GetString("source_security_group_id") != nil && *attrs.GetString("source_security_group_id") != "") { + if attrs.GetBool("self") != nil && *attrs.GetBool("self") { + buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("security_group_id"))) + } else { + buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("source_security_group_id"))) + } + buf.WriteString("-") + } + + return fmt.Sprintf("sgrule-%d", hashcode.String(buf.String())) +} + +func initAwsSecurityGroupRuleMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSecurityGroupRuleResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]interface{}) + if v, ok := res.Attributes().Get("type"); ok { + attrs["type"] = v + } + if v, ok := res.Attributes().Get("protocol"); ok { + attrs["protocol"] = v + } + if v := res.Attributes().GetInt("from_port"); v != nil { + attrs["from_port"] = *v + } + if v := res.Attributes().GetInt("to_port"); v != nil { + attrs["to_port"] = *v + } + if v, ok := res.Attributes().Get("security_group_id"); ok { + attrs["security_group_id"] = v + } + if v, ok := res.Attributes().Get("self"); ok { + attrs["self"] = v + } + if v, ok := res.Attributes().Get("cidr_blocks"); ok { + attrs["cidr_blocks"] = v + } + if v, ok := res.Attributes().Get("ipv6_cidr_blocks"); ok { + attrs["ipv6_cidr_blocks"] = v + } + if v, ok := res.Attributes().Get("prefix_list_ids"); ok { + attrs["prefix_list_ids"] = v + } + if v, ok := res.Attributes().Get("source_security_group_id"); ok { + attrs["source_security_group_id"] = v + } + return flatmap.Flatten(attrs) + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsSecurityGroupRuleResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if sgID := val.GetString("security_group_id"); sgID != nil && *sgID != "" { + attrs["SecurityGroup"] = *sgID + } + if protocol := val.GetString("protocol"); protocol != nil && *protocol != "" { + if *protocol == "-1" { + *protocol = "All" + } + attrs["Protocol"] = *protocol + } + fromPort := val.GetInt("from_port") + toPort := val.GetInt("to_port") + if fromPort != nil && toPort != nil { + portRange := "All" + if *fromPort != 0 && *fromPort == *toPort { + portRange = fmt.Sprintf("%d", *fromPort) + } + if *fromPort != 0 && *toPort != 0 && *fromPort != *toPort { + portRange = fmt.Sprintf("%d-%d", *fromPort, *toPort) + } + attrs["Ports"] = portRange + } + ty := val.GetString("type") + if ty != nil && *ty != "" { + attrs["Type"] = *ty + var sourceOrDestination string + switch *ty { + case "egress": + sourceOrDestination = "Destination" + case "ingress": + sourceOrDestination = "Source" + } + if ipv4 := val.GetSlice("cidr_blocks"); len(ipv4) > 0 { + attrs[sourceOrDestination] = join(ipv4, ", ") + } + if ipv6 := val.GetSlice("ipv6_cidr_blocks"); len(ipv6) > 0 { + attrs[sourceOrDestination] = join(ipv6, ", ") + } + if prefixList := val.GetSlice("prefix_list_ids"); len(prefixList) > 0 { + attrs[sourceOrDestination] = join(prefixList, ", ") + } + if sourceSgID := val.GetString("source_security_group_id"); sourceSgID != nil && *sourceSgID != "" { + attrs[sourceOrDestination] = *sourceSgID + } + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsSecurityGroupRuleResourceType, resource.FlagDeepMode) +} + +func join(elems []interface{}, sep string) string { + firstElemt, ok := elems[0].(string) + if !ok { + panic("cannot join a slice that contains something else than strings") + } + switch len(elems) { + case 0: + return "" + case 1: + + return firstElemt + } + n := len(sep) * (len(elems) - 1) + for i := 0; i < len(elems); i++ { + n += len(elems[i].(string)) + } + + var b strings.Builder + b.Grow(n) + b.WriteString(firstElemt) + for _, s := range elems[1:] { + b.WriteString(sep) + elem, ok := s.(string) + if !ok { + panic("cannot join a slice that contains something else than strings") + } + b.WriteString(elem) + } + return b.String() +} diff --git a/enumeration/resource/aws/aws_sns_topic.go b/enumeration/resource/aws/aws_sns_topic.go new file mode 100644 index 00000000..c60da86e --- /dev/null +++ b/enumeration/resource/aws/aws_sns_topic.go @@ -0,0 +1,33 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsSnsTopicResourceType = "aws_sns_topic" + +func initSnsTopicMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSnsTopicResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "topic_arn": res.ResourceId(), + } + }) + resourceSchemaRepository.UpdateSchema(AwsSnsTopicResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "delivery_policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsSnsTopicResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + if displayName := val.GetString("display_name"); displayName != nil && *displayName != "" { + attrs["DisplayName"] = *displayName + } + } + return attrs + }) + resourceSchemaRepository.SetFlags(AwsSnsTopicResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_sns_topic_policy.go b/enumeration/resource/aws/aws_sns_topic_policy.go new file mode 100644 index 00000000..25869100 --- /dev/null +++ b/enumeration/resource/aws/aws_sns_topic_policy.go @@ -0,0 +1,22 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsSnsTopicPolicyResourceType = "aws_sns_topic_policy" + +func initSnsTopicPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSnsTopicPolicyResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "topic_arn": res.ResourceId(), + } + }) + + resourceSchemaRepository.UpdateSchema(AwsSnsTopicPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsSnsTopicPolicyResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_sns_topic_subscription.go b/enumeration/resource/aws/aws_sns_topic_subscription.go new file mode 100644 index 00000000..0cde9ec5 --- /dev/null +++ b/enumeration/resource/aws/aws_sns_topic_subscription.go @@ -0,0 +1,26 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsSnsTopicSubscriptionResourceType = "aws_sns_topic_subscription" + +func initSnsTopicSubscriptionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSnsTopicSubscriptionResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "SubscriptionId": res.ResourceId(), + } + }) + + resourceSchemaRepository.UpdateSchema(AwsSnsTopicSubscriptionResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "delivery_policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + "filter_policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + + resourceSchemaRepository.SetFlags(AwsSnsTopicSubscriptionResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_sqs_queue.go b/enumeration/resource/aws/aws_sqs_queue.go new file mode 100644 index 00000000..b6ee1dcf --- /dev/null +++ b/enumeration/resource/aws/aws_sqs_queue.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsSqsQueueResourceType = "aws_sqs_queue" + +func initSqsQueueMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsSqsQueueResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_sqs_queue_policy.go b/enumeration/resource/aws/aws_sqs_queue_policy.go new file mode 100644 index 00000000..422e2d03 --- /dev/null +++ b/enumeration/resource/aws/aws_sqs_queue_policy.go @@ -0,0 +1,16 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsSqsQueuePolicyResourceType = "aws_sqs_queue_policy" + +func initAwsSQSQueuePolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.UpdateSchema(AwsSqsQueuePolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ + "policy": func(attributeSchema *resource.AttributeSchema) { + attributeSchema.JsonString = true + }, + }) + resourceSchemaRepository.SetFlags(AwsSqsQueuePolicyResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_subnet.go b/enumeration/resource/aws/aws_subnet.go new file mode 100644 index 00000000..e25361a1 --- /dev/null +++ b/enumeration/resource/aws/aws_subnet.go @@ -0,0 +1,11 @@ +package aws + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AwsSubnetResourceType = "aws_subnet" + +func initAwsSubnetMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsSubnetResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/aws_vpc.go b/enumeration/resource/aws/aws_vpc.go new file mode 100644 index 00000000..8d7ae9ac --- /dev/null +++ b/enumeration/resource/aws/aws_vpc.go @@ -0,0 +1,9 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +const AwsVpcResourceType = "aws_vpc" + +func initAwsVpcMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AwsVpcResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/aws/metadata_test.go b/enumeration/resource/aws/metadata_test.go new file mode 100644 index 00000000..3dafb326 --- /dev/null +++ b/enumeration/resource/aws/metadata_test.go @@ -0,0 +1,141 @@ +package aws + +import ( + tf "github.com/snyk/driftctl/enumeration/terraform" + + "testing" + + "github.com/snyk/driftctl/enumeration/resource" + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" +) + +func TestAWS_Metadata_Flags(t *testing.T) { + testcases := map[string][]resource.Flags{ + AwsAmiResourceType: {resource.FlagDeepMode}, + AwsApiGatewayAccountResourceType: {}, + AwsApiGatewayApiKeyResourceType: {}, + AwsApiGatewayAuthorizerResourceType: {}, + AwsApiGatewayBasePathMappingResourceType: {}, + AwsApiGatewayDeploymentResourceType: {}, + AwsApiGatewayDomainNameResourceType: {}, + AwsApiGatewayGatewayResponseResourceType: {}, + AwsApiGatewayIntegrationResourceType: {}, + AwsApiGatewayIntegrationResponseResourceType: {}, + AwsApiGatewayMethodResourceType: {}, + AwsApiGatewayMethodResponseResourceType: {}, + AwsApiGatewayMethodSettingsResourceType: {}, + AwsApiGatewayModelResourceType: {}, + AwsApiGatewayRequestValidatorResourceType: {}, + AwsApiGatewayResourceResourceType: {}, + AwsApiGatewayRestApiResourceType: {}, + AwsApiGatewayRestApiPolicyResourceType: {}, + AwsApiGatewayStageResourceType: {}, + AwsApiGatewayVpcLinkResourceType: {}, + AwsApiGatewayV2ApiResourceType: {}, + AwsApiGatewayV2RouteResourceType: {}, + AwsApiGatewayV2DeploymentResourceType: {}, + AwsApiGatewayV2VpcLinkResourceType: {}, + AwsApiGatewayV2AuthorizerResourceType: {}, + AwsApiGatewayV2RouteResponseResourceType: {}, + AwsApiGatewayV2DomainNameResourceType: {}, + AwsApiGatewayV2ModelResourceType: {}, + AwsApiGatewayV2StageResourceType: {}, + AwsApiGatewayV2MappingResourceType: {}, + AwsApiGatewayV2IntegrationResourceType: {}, + AwsApiGatewayV2IntegrationResponseResourceType: {}, + AwsAppAutoscalingPolicyResourceType: {resource.FlagDeepMode}, + AwsAppAutoscalingScheduledActionResourceType: {}, + AwsAppAutoscalingTargetResourceType: {resource.FlagDeepMode}, + AwsCloudformationStackResourceType: {resource.FlagDeepMode}, + AwsCloudfrontDistributionResourceType: {resource.FlagDeepMode}, + AwsDbInstanceResourceType: {resource.FlagDeepMode}, + AwsDbSubnetGroupResourceType: {resource.FlagDeepMode}, + AwsDefaultNetworkACLResourceType: {resource.FlagDeepMode}, + AwsDefaultRouteTableResourceType: {resource.FlagDeepMode}, + AwsDefaultSecurityGroupResourceType: {resource.FlagDeepMode}, + AwsDefaultSubnetResourceType: {resource.FlagDeepMode}, + AwsDefaultVpcResourceType: {resource.FlagDeepMode}, + AwsDynamodbTableResourceType: {resource.FlagDeepMode}, + AwsEbsEncryptionByDefaultResourceType: {resource.FlagDeepMode}, + AwsEbsSnapshotResourceType: {resource.FlagDeepMode}, + AwsEbsVolumeResourceType: {resource.FlagDeepMode}, + AwsEcrRepositoryResourceType: {resource.FlagDeepMode}, + AwsEipResourceType: {resource.FlagDeepMode}, + AwsEipAssociationResourceType: {resource.FlagDeepMode}, + AwsElastiCacheClusterResourceType: {}, + AwsIamAccessKeyResourceType: {resource.FlagDeepMode}, + AwsIamPolicyResourceType: {resource.FlagDeepMode}, + AwsIamPolicyAttachmentResourceType: {resource.FlagDeepMode}, + AwsIamRoleResourceType: {resource.FlagDeepMode}, + AwsIamRolePolicyResourceType: {resource.FlagDeepMode}, + AwsIamRolePolicyAttachmentResourceType: {resource.FlagDeepMode}, + AwsIamUserResourceType: {resource.FlagDeepMode}, + AwsIamUserPolicyResourceType: {resource.FlagDeepMode}, + AwsIamUserPolicyAttachmentResourceType: {resource.FlagDeepMode}, + AwsIamGroupPolicyResourceType: {}, + AwsIamGroupPolicyAttachmentResourceType: {}, + AwsInstanceResourceType: {resource.FlagDeepMode}, + AwsInternetGatewayResourceType: {resource.FlagDeepMode}, + AwsKeyPairResourceType: {resource.FlagDeepMode}, + AwsKmsAliasResourceType: {resource.FlagDeepMode}, + AwsKmsKeyResourceType: {resource.FlagDeepMode}, + AwsLambdaEventSourceMappingResourceType: {resource.FlagDeepMode}, + AwsLambdaFunctionResourceType: {resource.FlagDeepMode}, + AwsNatGatewayResourceType: {resource.FlagDeepMode}, + AwsNetworkACLResourceType: {resource.FlagDeepMode}, + AwsRDSClusterResourceType: {resource.FlagDeepMode}, + AwsRDSClusterInstanceResourceType: {}, + AwsRouteResourceType: {resource.FlagDeepMode}, + AwsRoute53HealthCheckResourceType: {resource.FlagDeepMode}, + AwsRoute53RecordResourceType: {resource.FlagDeepMode}, + AwsRoute53ZoneResourceType: {resource.FlagDeepMode}, + AwsRouteTableResourceType: {resource.FlagDeepMode}, + AwsRouteTableAssociationResourceType: {resource.FlagDeepMode}, + AwsS3BucketResourceType: {resource.FlagDeepMode}, + AwsS3BucketAnalyticsConfigurationResourceType: {resource.FlagDeepMode}, + AwsS3BucketInventoryResourceType: {resource.FlagDeepMode}, + AwsS3BucketMetricResourceType: {resource.FlagDeepMode}, + AwsS3BucketNotificationResourceType: {resource.FlagDeepMode}, + AwsS3BucketPolicyResourceType: {resource.FlagDeepMode}, + AwsS3BucketPublicAccessBlockResourceType: {}, + AwsSecurityGroupResourceType: {resource.FlagDeepMode}, + AwsSnsTopicResourceType: {resource.FlagDeepMode}, + AwsSnsTopicPolicyResourceType: {resource.FlagDeepMode}, + AwsSnsTopicSubscriptionResourceType: {resource.FlagDeepMode}, + AwsSqsQueueResourceType: {resource.FlagDeepMode}, + AwsSqsQueuePolicyResourceType: {resource.FlagDeepMode}, + AwsSubnetResourceType: {resource.FlagDeepMode}, + AwsVpcResourceType: {resource.FlagDeepMode}, + AwsSecurityGroupRuleResourceType: {resource.FlagDeepMode}, + AwsNetworkACLRuleResourceType: {resource.FlagDeepMode}, + AwsLaunchTemplateResourceType: {resource.FlagDeepMode}, + AwsLaunchConfigurationResourceType: {}, + AwsLoadBalancerResourceType: {}, + AwsApplicationLoadBalancerResourceType: {}, + AwsClassicLoadBalancerResourceType: {}, + AwsLoadBalancerListenerResourceType: {}, + AwsApplicationLoadBalancerListenerResourceType: {}, + AwsIamGroupResourceType: {}, + AwsEcrRepositoryPolicyResourceType: {}, + } + + schemaRepository := testresource.InitFakeSchemaRepository(tf.AWS, "3.19.0") + InitResourcesMetadata(schemaRepository) + + for ty, flags := range testcases { + t.Run(ty, func(tt *testing.T) { + sch, exist := schemaRepository.GetSchema(ty) + assert.True(tt, exist) + + if len(flags) == 0 { + assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") + return + } + + for _, flag := range flags { + assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) + } + }) + } +} diff --git a/enumeration/resource/aws/metadatas.go b/enumeration/resource/aws/metadatas.go new file mode 100644 index 00000000..c6aac2b9 --- /dev/null +++ b/enumeration/resource/aws/metadatas.go @@ -0,0 +1,70 @@ +package aws + +import "github.com/snyk/driftctl/enumeration/resource" + +func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + initAwsAmiMetaData(resourceSchemaRepository) + initAwsCloudfrontDistributionMetaData(resourceSchemaRepository) + initAwsDbInstanceMetaData(resourceSchemaRepository) + initAwsDbSubnetGroupMetaData(resourceSchemaRepository) + initAwsDefaultSecurityGroupMetaData(resourceSchemaRepository) + initAwsDefaultSubnetMetaData(resourceSchemaRepository) + initAwsDefaultVpcMetaData(resourceSchemaRepository) + initAwsDefaultRouteTableMetadata(resourceSchemaRepository) + initAwsDynamodbTableMetaData(resourceSchemaRepository) + initAwsEbsSnapshotMetaData(resourceSchemaRepository) + initAwsInstanceMetaData(resourceSchemaRepository) + initAwsInternetGatewayMetaData(resourceSchemaRepository) + initAwsEbsVolumeMetaData(resourceSchemaRepository) + initAwsEipMetaData(resourceSchemaRepository) + initAwsEipAssociationMetaData(resourceSchemaRepository) + initAwsS3BucketMetaData(resourceSchemaRepository) + initAwsS3BucketPolicyMetaData(resourceSchemaRepository) + initAwsS3BucketInventoryMetadata(resourceSchemaRepository) + initAwsS3BucketMetricMetadata(resourceSchemaRepository) + initAwsS3BucketNotificationMetadata(resourceSchemaRepository) + initAwsS3BucketAnalyticsConfigurationMetaData(resourceSchemaRepository) + initAwsEcrRepositoryMetaData(resourceSchemaRepository) + initAwsRouteMetaData(resourceSchemaRepository) + initAwsRouteTableAssociationMetaData(resourceSchemaRepository) + initAwsRoute53RecordMetaData(resourceSchemaRepository) + initAwsRoute53ZoneMetaData(resourceSchemaRepository) + initAwsRoute53HealthCheckMetaData(resourceSchemaRepository) + initAwsRouteTableMetaData(resourceSchemaRepository) + initSnsTopicSubscriptionMetaData(resourceSchemaRepository) + initSnsTopicPolicyMetaData(resourceSchemaRepository) + initSnsTopicMetaData(resourceSchemaRepository) + initSqsQueueMetaData(resourceSchemaRepository) + initAwsIAMAccessKeyMetaData(resourceSchemaRepository) + initAwsIAMPolicyMetaData(resourceSchemaRepository) + initAwsIAMPolicyAttachmentMetaData(resourceSchemaRepository) + initAwsIAMRoleMetaData(resourceSchemaRepository) + initAwsIAMRolePolicyMetaData(resourceSchemaRepository) + initAwsIamRolePolicyAttachmentMetaData(resourceSchemaRepository) + initAwsIamUserPolicyAttachmentMetaData(resourceSchemaRepository) + initAwsIAMUserMetaData(resourceSchemaRepository) + initAwsIAMUserPolicyMetaData(resourceSchemaRepository) + initAwsKeyPairMetaData(resourceSchemaRepository) + initAwsKmsKeyMetaData(resourceSchemaRepository) + initAwsKmsAliasMetaData(resourceSchemaRepository) + initAwsLambdaFunctionMetaData(resourceSchemaRepository) + initAwsLambdaEventSourceMappingMetaData(resourceSchemaRepository) + initNatGatewayMetaData(resourceSchemaRepository) + initAwsNetworkACLMetaData(resourceSchemaRepository) + initAwsNetworkACLRuleMetaData(resourceSchemaRepository) + initAwsDefaultNetworkACLMetaData(resourceSchemaRepository) + initAwsSubnetMetaData(resourceSchemaRepository) + initAwsSQSQueuePolicyMetaData(resourceSchemaRepository) + initAwsSecurityGroupRuleMetaData(resourceSchemaRepository) + initAwsSecurityGroupMetaData(resourceSchemaRepository) + initAwsRDSClusterMetaData(resourceSchemaRepository) + initAwsCloudformationStackMetaData(resourceSchemaRepository) + initAwsVpcMetaData(resourceSchemaRepository) + initAwsAppAutoscalingTargetMetaData(resourceSchemaRepository) + initAwsAppAutoscalingPolicyMetaData(resourceSchemaRepository) + initAwsLaunchTemplateMetaData(resourceSchemaRepository) + initAwsApiGatewayV2ModelMetaData(resourceSchemaRepository) + initAwsApiGatewayV2MappingMetaData(resourceSchemaRepository) + initAwsEbsEncryptionByDefaultMetaData(resourceSchemaRepository) + initAwsLoadBalancerMetaData(resourceSchemaRepository) +} diff --git a/enumeration/resource/azurerm/azurerm_container_registry.go b/enumeration/resource/azurerm/azurerm_container_registry.go new file mode 100644 index 00000000..4a6d52b6 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_container_registry.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzureContainerRegistryResourceType = "azurerm_container_registry" + +func initAzureContainerRegistryMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureContainerRegistryResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_firewall.go b/enumeration/resource/azurerm/azurerm_firewall.go new file mode 100644 index 00000000..ded79fe2 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_firewall.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzureFirewallResourceType = "azurerm_firewall" + +func initAzureFirewallMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureFirewallResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_image.go b/enumeration/resource/azurerm/azurerm_image.go new file mode 100644 index 00000000..a86e8d2e --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_image.go @@ -0,0 +1,19 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzureImageResourceType = "azurerm_image" + +func initAzureImageMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureImageResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_lb.go b/enumeration/resource/azurerm/azurerm_lb.go new file mode 100644 index 00000000..04a7cc0d --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_lb.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzureLoadBalancerResourceType = "azurerm_lb" + +func initAzureLoadBalancerMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureLoadBalancerResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_lb_rule.go b/enumeration/resource/azurerm/azurerm_lb_rule.go new file mode 100644 index 00000000..ea5a5fcd --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_lb_rule.go @@ -0,0 +1,21 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzureLoadBalancerRuleResourceType = "azurerm_lb_rule" + +func initAzureLoadBalancerRuleMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(AzureLoadBalancerRuleResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "loadbalancer_id": *res.Attributes().GetString("loadbalancer_id"), + } + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureLoadBalancerRuleResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + if name := res.Attributes().GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzureLoadBalancerRuleResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_network_security_group.go b/enumeration/resource/azurerm/azurerm_network_security_group.go new file mode 100644 index 00000000..29303ba1 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_network_security_group.go @@ -0,0 +1,17 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzureNetworkSecurityGroupResourceType = "azurerm_network_security_group" + +func initAzureNetworkSecurityGroupMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureNetworkSecurityGroupResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzureNetworkSecurityGroupResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_postgresql_database.go b/enumeration/resource/azurerm/azurerm_postgresql_database.go new file mode 100644 index 00000000..7a5784f9 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_postgresql_database.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzurePostgresqlDatabaseResourceType = "azurerm_postgresql_database" + +func initAzurePostgresqlDatabaseMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePostgresqlDatabaseResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_postgresql_server.go b/enumeration/resource/azurerm/azurerm_postgresql_server.go new file mode 100644 index 00000000..167e68c3 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_postgresql_server.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzurePostgresqlServerResourceType = "azurerm_postgresql_server" + +func initAzurePostgresqlServerMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePostgresqlServerResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_a_record.go b/enumeration/resource/azurerm/azurerm_private_dns_a_record.go new file mode 100644 index 00000000..014ebb7f --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_a_record.go @@ -0,0 +1,22 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSARecordResourceType = "azurerm_private_dns_a_record" + +func initAzurePrivateDNSARecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSARecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzurePrivateDNSARecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_aaaa_record.go b/enumeration/resource/azurerm/azurerm_private_dns_aaaa_record.go new file mode 100644 index 00000000..ad358292 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_aaaa_record.go @@ -0,0 +1,22 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSAAAARecordResourceType = "azurerm_private_dns_aaaa_record" + +func initAzurePrivateDNSAAAARecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSAAAARecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzurePrivateDNSAAAARecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_cname_record.go b/enumeration/resource/azurerm/azurerm_private_dns_cname_record.go new file mode 100644 index 00000000..5673ddff --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_cname_record.go @@ -0,0 +1,21 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzurePrivateDNSCNameRecordResourceType = "azurerm_private_dns_cname_record" + +func initAzurePrivateDNSCNameRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AzurePrivateDNSCNameRecordResourceType, resource.FlagDeepMode) + + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSCNameRecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_mx_record.go b/enumeration/resource/azurerm/azurerm_private_dns_mx_record.go new file mode 100644 index 00000000..478a721c --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_mx_record.go @@ -0,0 +1,25 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSMXRecordResourceType = "azurerm_private_dns_mx_record" + +func initAzurePrivateDNSMXRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSMXRecordResourceType, func(res *resource.Resource) { + res.Attributes().SafeDelete([]string{"timeouts"}) + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSMXRecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzurePrivateDNSMXRecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_ptr_record.go b/enumeration/resource/azurerm/azurerm_private_dns_ptr_record.go new file mode 100644 index 00000000..59ea6135 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_ptr_record.go @@ -0,0 +1,22 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSPTRRecordResourceType = "azurerm_private_dns_ptr_record" + +func initAzurePrivateDNSPTRRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSPTRRecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzurePrivateDNSPTRRecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_srv_record.go b/enumeration/resource/azurerm/azurerm_private_dns_srv_record.go new file mode 100644 index 00000000..91942d1c --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_srv_record.go @@ -0,0 +1,22 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSSRVRecordResourceType = "azurerm_private_dns_srv_record" + +func initAzurePrivateDNSSRVRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSSRVRecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzurePrivateDNSSRVRecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_txt_record.go b/enumeration/resource/azurerm/azurerm_private_dns_txt_record.go new file mode 100644 index 00000000..17367542 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_txt_record.go @@ -0,0 +1,22 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSTXTRecordResourceType = "azurerm_private_dns_txt_record" + +func initAzurePrivateDNSTXTRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSTXTRecordResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + if zone := val.GetString("zone_name"); zone != nil && *zone != "" { + attrs["Zone"] = *zone + } + return attrs + }) + resourceSchemaRepository.SetFlags(AzurePrivateDNSTXTRecordResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_private_dns_zone.go b/enumeration/resource/azurerm/azurerm_private_dns_zone.go new file mode 100644 index 00000000..19075ea3 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_private_dns_zone.go @@ -0,0 +1,11 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzurePrivateDNSZoneResourceType = "azurerm_private_dns_zone" + +func initAzurePrivateDNSZoneMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(AzurePrivateDNSZoneResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/azurerm/azurerm_public_ip.go b/enumeration/resource/azurerm/azurerm_public_ip.go new file mode 100644 index 00000000..ac3d0710 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_public_ip.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzurePublicIPResourceType = "azurerm_public_ip" + +func initAzurePublicIPMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePublicIPResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_resource_group.go b/enumeration/resource/azurerm/azurerm_resource_group.go new file mode 100644 index 00000000..f30885cb --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_resource_group.go @@ -0,0 +1,16 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +const AzureResourceGroupResourceType = "azurerm_resource_group" + +func initAzureResourceGroupMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureResourceGroupResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_route.go b/enumeration/resource/azurerm/azurerm_route.go new file mode 100644 index 00000000..79789fc1 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_route.go @@ -0,0 +1,23 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzureRouteResourceType = "azurerm_route" + +func initAzureRouteMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureRouteResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + + if v := res.Attributes().GetString("route_table_name"); v != nil && *v != "" { + attrs["Table"] = *v + } + + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_route_table.go b/enumeration/resource/azurerm/azurerm_route_table.go new file mode 100644 index 00000000..d0056f6a --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_route_table.go @@ -0,0 +1,18 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzureRouteTableResourceType = "azurerm_route_table" + +func initAzureRouteTableMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureRouteTableResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/azurerm_ssh_public_key.go b/enumeration/resource/azurerm/azurerm_ssh_public_key.go new file mode 100644 index 00000000..1d648ba2 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_ssh_public_key.go @@ -0,0 +1,20 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzureSSHPublicKeyResourceType = "azurerm_ssh_public_key" + +func initAzureSSHPublicKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureSSHPublicKeyResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + + return attrs + }) + resourceSchemaRepository.SetFlags(AzureSSHPublicKeyResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/azurerm/azurerm_storage_account.go b/enumeration/resource/azurerm/azurerm_storage_account.go similarity index 100% rename from pkg/resource/azurerm/azurerm_storage_account.go rename to enumeration/resource/azurerm/azurerm_storage_account.go diff --git a/pkg/resource/azurerm/azurerm_storage_container.go b/enumeration/resource/azurerm/azurerm_storage_container.go similarity index 100% rename from pkg/resource/azurerm/azurerm_storage_container.go rename to enumeration/resource/azurerm/azurerm_storage_container.go diff --git a/pkg/resource/azurerm/azurerm_subnet.go b/enumeration/resource/azurerm/azurerm_subnet.go similarity index 100% rename from pkg/resource/azurerm/azurerm_subnet.go rename to enumeration/resource/azurerm/azurerm_subnet.go diff --git a/enumeration/resource/azurerm/azurerm_virtual_network.go b/enumeration/resource/azurerm/azurerm_virtual_network.go new file mode 100644 index 00000000..790413f1 --- /dev/null +++ b/enumeration/resource/azurerm/azurerm_virtual_network.go @@ -0,0 +1,18 @@ +package azurerm + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +const AzureVirtualNetworkResourceType = "azurerm_virtual_network" + +func initAzureVirtualNetworkMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureVirtualNetworkResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + return attrs + }) +} diff --git a/enumeration/resource/azurerm/metadata.go b/enumeration/resource/azurerm/metadata.go new file mode 100644 index 00000000..5051e226 --- /dev/null +++ b/enumeration/resource/azurerm/metadata.go @@ -0,0 +1,28 @@ +package azurerm + +import "github.com/snyk/driftctl/enumeration/resource" + +func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + initAzureVirtualNetworkMetaData(resourceSchemaRepository) + initAzureRouteTableMetaData(resourceSchemaRepository) + initAzureRouteMetaData(resourceSchemaRepository) + initAzureResourceGroupMetadata(resourceSchemaRepository) + initAzureContainerRegistryMetadata(resourceSchemaRepository) + initAzureFirewallMetadata(resourceSchemaRepository) + initAzurePostgresqlServerMetadata(resourceSchemaRepository) + initAzurePublicIPMetadata(resourceSchemaRepository) + initAzurePostgresqlDatabaseMetadata(resourceSchemaRepository) + initAzureNetworkSecurityGroupMetadata(resourceSchemaRepository) + initAzureLoadBalancerMetadata(resourceSchemaRepository) + initAzurePrivateDNSZoneMetaData(resourceSchemaRepository) + initAzurePrivateDNSARecordMetaData(resourceSchemaRepository) + initAzurePrivateDNSAAAARecordMetaData(resourceSchemaRepository) + initAzurePrivateDNSPTRRecordMetaData(resourceSchemaRepository) + initAzurePrivateDNSSRVRecordMetaData(resourceSchemaRepository) + initAzurePrivateDNSMXRecordMetaData(resourceSchemaRepository) + initAzurePrivateDNSTXTRecordMetaData(resourceSchemaRepository) + initAzureImageMetaData(resourceSchemaRepository) + initAzureSSHPublicKeyMetaData(resourceSchemaRepository) + initAzurePrivateDNSCNameRecordMetaData(resourceSchemaRepository) + initAzureLoadBalancerRuleMetadata(resourceSchemaRepository) +} diff --git a/enumeration/resource/azurerm/metadata_test.go b/enumeration/resource/azurerm/metadata_test.go new file mode 100644 index 00000000..6dff89a1 --- /dev/null +++ b/enumeration/resource/azurerm/metadata_test.go @@ -0,0 +1,60 @@ +package azurerm + +import ( + tf "github.com/snyk/driftctl/enumeration/terraform" + + "testing" + + "github.com/snyk/driftctl/enumeration/resource" + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" +) + +func TestAzureMetadata_Flags(t *testing.T) { + testcases := map[string][]resource.Flags{ + AzureContainerRegistryResourceType: {}, + AzureFirewallResourceType: {}, + AzurePostgresqlServerResourceType: {}, + AzurePostgresqlDatabaseResourceType: {}, + AzurePublicIPResourceType: {}, + AzureResourceGroupResourceType: {}, + AzureRouteResourceType: {}, + AzureRouteTableResourceType: {}, + AzureStorageAccountResourceType: {}, + AzureStorageContainerResourceType: {}, + AzureSubnetResourceType: {}, + AzureVirtualNetworkResourceType: {}, + AzureNetworkSecurityGroupResourceType: {resource.FlagDeepMode}, + AzureLoadBalancerResourceType: {}, + AzurePrivateDNSZoneResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSARecordResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSAAAARecordResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSCNameRecordResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSPTRRecordResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSMXRecordResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSSRVRecordResourceType: {resource.FlagDeepMode}, + AzurePrivateDNSTXTRecordResourceType: {resource.FlagDeepMode}, + AzureImageResourceType: {}, + AzureSSHPublicKeyResourceType: {resource.FlagDeepMode}, + AzureLoadBalancerRuleResourceType: {resource.FlagDeepMode}, + } + + schemaRepository := testresource.InitFakeSchemaRepository(tf.AZURE, "2.71.0") + InitResourcesMetadata(schemaRepository) + + for ty, flags := range testcases { + t.Run(ty, func(tt *testing.T) { + sch, exist := schemaRepository.GetSchema(ty) + assert.True(tt, exist) + + if len(flags) == 0 { + assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") + return + } + + for _, flag := range flags { + assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) + } + }) + } +} diff --git a/enumeration/resource/deserializer.go b/enumeration/resource/deserializer.go new file mode 100644 index 00000000..aa71080a --- /dev/null +++ b/enumeration/resource/deserializer.go @@ -0,0 +1,48 @@ +package resource + +import ( + "encoding/json" + + "github.com/zclconf/go-cty/cty" + ctyjson "github.com/zclconf/go-cty/cty/json" +) + +type Deserializer struct { + factory ResourceFactory +} + +func NewDeserializer(factory ResourceFactory) *Deserializer { + return &Deserializer{factory} +} + +func (s *Deserializer) Deserialize(ty string, rawList []cty.Value) ([]*Resource, error) { + resources := make([]*Resource, 0) + for _, rawResource := range rawList { + rawResource := rawResource + res, err := s.DeserializeOne(ty, rawResource) + if err != nil { + return nil, err + } + resources = append(resources, res) + } + return resources, nil +} + +func (s *Deserializer) DeserializeOne(ty string, value cty.Value) (*Resource, error) { + if value.IsNull() { + return nil, nil + } + + // Marked values cannot be deserialized to JSON. + // For example, this ensures we can deserialize sensitive values too. + unmarkedVal, _ := value.UnmarkDeep() + + var attrs Attributes + bytes, _ := ctyjson.Marshal(unmarkedVal, unmarkedVal.Type()) + err := json.Unmarshal(bytes, &attrs) + if err != nil { + return nil, err + } + + return s.factory.CreateAbstractResource(ty, value.GetAttr("id").AsString(), attrs), nil +} diff --git a/enumeration/resource/github/github_branch_protection.go b/enumeration/resource/github/github_branch_protection.go new file mode 100644 index 00000000..03c5b22d --- /dev/null +++ b/enumeration/resource/github/github_branch_protection.go @@ -0,0 +1,39 @@ +// GENERATED, DO NOT EDIT THIS FILE +package github + +import ( + "encoding/base64" + + "github.com/snyk/driftctl/enumeration/resource" +) + +const GithubBranchProtectionResourceType = "github_branch_protection" + +func initGithubBranchProtectionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GithubBranchProtectionResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + pattern := val.GetString("pattern") + repoID := val.GetString("repository_id") + if pattern != nil && *pattern != "" { + id := "" + if repoID != nil && *repoID != "" { + decodedID, err := base64.StdEncoding.DecodeString(*repoID) + if err == nil { + id = string(decodedID) + } + } + if id == "" { + attrs["Branch"] = *pattern + attrs["Id"] = res.ResourceId() + return attrs + } + attrs["Branch"] = *pattern + attrs["RepoId"] = id + return attrs + } + attrs["Id"] = res.ResourceId() + return attrs + }) + resourceSchemaRepository.SetFlags(GithubBranchProtectionResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/github/github_membership.go b/enumeration/resource/github/github_membership.go new file mode 100644 index 00000000..2a5dbd67 --- /dev/null +++ b/enumeration/resource/github/github_membership.go @@ -0,0 +1,10 @@ +// GENERATED, DO NOT EDIT THIS FILE +package github + +import "github.com/snyk/driftctl/enumeration/resource" + +const GithubMembershipResourceType = "github_membership" + +func initGithubMembershipMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(GithubMembershipResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/github/github_repository.go b/enumeration/resource/github/github_repository.go new file mode 100644 index 00000000..80f0e31c --- /dev/null +++ b/enumeration/resource/github/github_repository.go @@ -0,0 +1,9 @@ +package github + +import "github.com/snyk/driftctl/enumeration/resource" + +const GithubRepositoryResourceType = "github_repository" + +func initGithubRepositoryMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(GithubRepositoryResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/github/github_team.go b/enumeration/resource/github/github_team.go new file mode 100644 index 00000000..6d679b39 --- /dev/null +++ b/enumeration/resource/github/github_team.go @@ -0,0 +1,19 @@ +// GENERATED, DO NOT EDIT THIS FILE +package github + +import "github.com/snyk/driftctl/enumeration/resource" + +const GithubTeamResourceType = "github_team" + +func initGithubTeamMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GithubTeamResourceType, func(res *resource.Resource) map[string]string { + val := res.Attrs + attrs := make(map[string]string) + attrs["Id"] = res.ResourceId() + if name := val.GetString("name"); name != nil && *name != "" { + attrs["Name"] = *name + } + return attrs + }) + resourceSchemaRepository.SetFlags(GithubTeamResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/github/github_team_membership.go b/enumeration/resource/github/github_team_membership.go new file mode 100644 index 00000000..b2056879 --- /dev/null +++ b/enumeration/resource/github/github_team_membership.go @@ -0,0 +1,10 @@ +// GENERATED, DO NOT EDIT THIS FILE +package github + +import "github.com/snyk/driftctl/enumeration/resource" + +const GithubTeamMembershipResourceType = "github_team_membership" + +func initGithubTeamMembershipMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetFlags(GithubTeamMembershipResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/github/metadata_test.go b/enumeration/resource/github/metadata_test.go new file mode 100644 index 00000000..7f090258 --- /dev/null +++ b/enumeration/resource/github/metadata_test.go @@ -0,0 +1,40 @@ +package github + +import ( + tf "github.com/snyk/driftctl/enumeration/terraform" + + "testing" + + "github.com/snyk/driftctl/enumeration/resource" + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" +) + +func TestGitHub_Metadata_Flags(t *testing.T) { + testcases := map[string][]resource.Flags{ + GithubBranchProtectionResourceType: {resource.FlagDeepMode}, + GithubMembershipResourceType: {resource.FlagDeepMode}, + GithubTeamMembershipResourceType: {resource.FlagDeepMode}, + GithubRepositoryResourceType: {resource.FlagDeepMode}, + GithubTeamResourceType: {resource.FlagDeepMode}, + } + + schemaRepository := testresource.InitFakeSchemaRepository(tf.GITHUB, "4.4.0") + InitResourcesMetadata(schemaRepository) + + for ty, flags := range testcases { + t.Run(ty, func(tt *testing.T) { + sch, exist := schemaRepository.GetSchema(ty) + assert.True(tt, exist) + + if len(flags) == 0 { + assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") + return + } + + for _, flag := range flags { + assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) + } + }) + } +} diff --git a/enumeration/resource/github/metadatas.go b/enumeration/resource/github/metadatas.go new file mode 100644 index 00000000..24622785 --- /dev/null +++ b/enumeration/resource/github/metadatas.go @@ -0,0 +1,11 @@ +package github + +import "github.com/snyk/driftctl/enumeration/resource" + +func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + initGithubBranchProtectionMetaData(resourceSchemaRepository) + initGithubMembershipMetaData(resourceSchemaRepository) + initGithubRepositoryMetaData(resourceSchemaRepository) + initGithubTeamMetaData(resourceSchemaRepository) + initGithubTeamMembershipMetaData(resourceSchemaRepository) +} diff --git a/enumeration/resource/google/google_bigquery_dataset.go b/enumeration/resource/google/google_bigquery_dataset.go new file mode 100644 index 00000000..db663d49 --- /dev/null +++ b/enumeration/resource/google/google_bigquery_dataset.go @@ -0,0 +1,13 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleBigqueryDatasetResourceType = "google_bigquery_dataset" + +func initGoogleBigqueryDatasetMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleBigqueryDatasetResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attrs.GetString("friendly_name"), + } + }) +} diff --git a/enumeration/resource/google/google_bigquery_table.go b/enumeration/resource/google/google_bigquery_table.go new file mode 100644 index 00000000..9c39658d --- /dev/null +++ b/enumeration/resource/google/google_bigquery_table.go @@ -0,0 +1,13 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleBigqueryTableResourceType = "google_bigquery_table" + +func initGoogleBigqueryTableMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleBigqueryTableResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attrs.GetString("friendly_name"), + } + }) +} diff --git a/pkg/resource/google/google_bigtable_instance.go b/enumeration/resource/google/google_bigtable_instance.go similarity index 100% rename from pkg/resource/google/google_bigtable_instance.go rename to enumeration/resource/google/google_bigtable_instance.go diff --git a/pkg/resource/google/google_bigtable_table.go b/enumeration/resource/google/google_bigtable_table.go similarity index 100% rename from pkg/resource/google/google_bigtable_table.go rename to enumeration/resource/google/google_bigtable_table.go diff --git a/pkg/resource/google/google_cloudfunctions_function.go b/enumeration/resource/google/google_cloudfunctions_function.go similarity index 100% rename from pkg/resource/google/google_cloudfunctions_function.go rename to enumeration/resource/google/google_cloudfunctions_function.go diff --git a/pkg/resource/google/google_cloudrun_service.go b/enumeration/resource/google/google_cloudrun_service.go similarity index 100% rename from pkg/resource/google/google_cloudrun_service.go rename to enumeration/resource/google/google_cloudrun_service.go diff --git a/enumeration/resource/google/google_compute_address.go b/enumeration/resource/google/google_compute_address.go new file mode 100644 index 00000000..a3773285 --- /dev/null +++ b/enumeration/resource/google/google_compute_address.go @@ -0,0 +1,14 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeAddressResourceType = "google_compute_address" + +func initGoogleComputeAddressMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeAddressResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "Name": *res.Attributes().GetString("name"), + "Address": *res.Attributes().GetString("address"), + } + }) +} diff --git a/enumeration/resource/google/google_compute_disk.go b/enumeration/resource/google/google_compute_disk.go new file mode 100644 index 00000000..ed40a222 --- /dev/null +++ b/enumeration/resource/google/google_compute_disk.go @@ -0,0 +1,13 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeDiskResourceType = "google_compute_disk" + +func initGoogleComputeDiskMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeDiskResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "Name": *res.Attributes().GetString("name"), + } + }) +} diff --git a/enumeration/resource/google/google_compute_firewall.go b/enumeration/resource/google/google_compute_firewall.go new file mode 100644 index 00000000..4b836b5f --- /dev/null +++ b/enumeration/resource/google/google_compute_firewall.go @@ -0,0 +1,15 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeFirewallResourceType = "google_compute_firewall" + +func initGoogleComputeFirewallMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeFirewallResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attrs.GetString("name"), + "project": *res.Attrs.GetString("project"), + } + }) + resourceSchemaRepository.SetFlags(GoogleComputeFirewallResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/google/google_compute_forwarding_rule.go b/enumeration/resource/google/google_compute_forwarding_rule.go similarity index 100% rename from pkg/resource/google/google_compute_forwarding_rule.go rename to enumeration/resource/google/google_compute_forwarding_rule.go diff --git a/enumeration/resource/google/google_compute_global_address.go b/enumeration/resource/google/google_compute_global_address.go new file mode 100644 index 00000000..1c3af037 --- /dev/null +++ b/enumeration/resource/google/google_compute_global_address.go @@ -0,0 +1,14 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeGlobalAddressResourceType = "google_compute_global_address" + +func initGoogleComputeGlobalAddressMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeGlobalAddressResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "Name": *res.Attributes().GetString("name"), + "Address": *res.Attributes().GetString("address"), + } + }) +} diff --git a/pkg/resource/google/google_compute_global_forwarding_rule.go b/enumeration/resource/google/google_compute_global_forwarding_rule.go similarity index 100% rename from pkg/resource/google/google_compute_global_forwarding_rule.go rename to enumeration/resource/google/google_compute_global_forwarding_rule.go diff --git a/enumeration/resource/google/google_compute_health_check.go b/enumeration/resource/google/google_compute_health_check.go new file mode 100644 index 00000000..c33b8173 --- /dev/null +++ b/enumeration/resource/google/google_compute_health_check.go @@ -0,0 +1,13 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeHealthCheckResourceType = "google_compute_health_check" + +func initGoogleComputeHealthCheckMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeHealthCheckResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "Name": *res.Attributes().GetString("name"), + } + }) +} diff --git a/enumeration/resource/google/google_compute_image.go b/enumeration/resource/google/google_compute_image.go new file mode 100644 index 00000000..f52788da --- /dev/null +++ b/enumeration/resource/google/google_compute_image.go @@ -0,0 +1,13 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeImageResourceType = "google_compute_image" + +func initGoogleComputeImageMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeImageResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "Name": *res.Attributes().GetString("name"), + } + }) +} diff --git a/pkg/resource/google/google_compute_instance.go b/enumeration/resource/google/google_compute_instance.go similarity index 100% rename from pkg/resource/google/google_compute_instance.go rename to enumeration/resource/google/google_compute_instance.go diff --git a/enumeration/resource/google/google_compute_instance_group.go b/enumeration/resource/google/google_compute_instance_group.go new file mode 100644 index 00000000..e2cd6c4f --- /dev/null +++ b/enumeration/resource/google/google_compute_instance_group.go @@ -0,0 +1,23 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeInstanceGroupResourceType = "google_compute_instance_group" + +func initGoogleComputeInstanceGroupMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeInstanceGroupResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attributes().GetString("name"), + "project": *res.Attributes().GetString("project"), + "zone": *res.Attributes().GetString("location"), + } + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeInstanceGroupResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + return attrs + }) + resourceSchemaRepository.SetFlags(GoogleComputeInstanceGroupResourceType, resource.FlagDeepMode) +} diff --git a/enumeration/resource/google/google_compute_instance_group_manager.go b/enumeration/resource/google/google_compute_instance_group_manager.go new file mode 100644 index 00000000..35ab1fbb --- /dev/null +++ b/enumeration/resource/google/google_compute_instance_group_manager.go @@ -0,0 +1,15 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeInstanceGroupManagerResourceType = "google_compute_instance_group_manager" + +func initComputeInstanceGroupManagerMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeInstanceGroupManagerResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + return attrs + }) +} diff --git a/enumeration/resource/google/google_compute_network.go b/enumeration/resource/google/google_compute_network.go new file mode 100644 index 00000000..850fc61a --- /dev/null +++ b/enumeration/resource/google/google_compute_network.go @@ -0,0 +1,14 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeNetworkResourceType = "google_compute_network" + +func initGoogleComputeNetworkMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeNetworkResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attributes().GetString("name"), + } + }) + resourceSchemaRepository.SetFlags(GoogleComputeNetworkResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/google/google_compute_node_group.go b/enumeration/resource/google/google_compute_node_group.go similarity index 100% rename from pkg/resource/google/google_compute_node_group.go rename to enumeration/resource/google/google_compute_node_group.go diff --git a/enumeration/resource/google/google_compute_router.go b/enumeration/resource/google/google_compute_router.go new file mode 100644 index 00000000..43ae2b12 --- /dev/null +++ b/enumeration/resource/google/google_compute_router.go @@ -0,0 +1,15 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeRouterResourceType = "google_compute_router" + +func initGoogleComputeRouterMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeRouterResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attrs.GetString("name"), + "region": *res.Attrs.GetString("region"), + "project": *res.Attrs.GetString("project"), + } + }) +} diff --git a/enumeration/resource/google/google_compute_subnetwork.go b/enumeration/resource/google/google_compute_subnetwork.go new file mode 100644 index 00000000..c88670d9 --- /dev/null +++ b/enumeration/resource/google/google_compute_subnetwork.go @@ -0,0 +1,23 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleComputeSubnetworkResourceType = "google_compute_subnetwork" + +func initGoogleComputeSubnetworkMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeSubnetworkResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": *res.Attributes().GetString("name"), + "region": *res.Attributes().GetString("region"), + } + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeSubnetworkResourceType, func(res *resource.Resource) map[string]string { + attrs := make(map[string]string) + + if v := res.Attributes().GetString("name"); v != nil && *v != "" { + attrs["Name"] = *v + } + return attrs + }) + resourceSchemaRepository.SetFlags(GoogleComputeSubnetworkResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/google/google_dns_managed_zone.go b/enumeration/resource/google/google_dns_managed_zone.go similarity index 100% rename from pkg/resource/google/google_dns_managed_zone.go rename to enumeration/resource/google/google_dns_managed_zone.go diff --git a/pkg/resource/google/google_project_iam_binding.go b/enumeration/resource/google/google_project_iam_binding.go similarity index 100% rename from pkg/resource/google/google_project_iam_binding.go rename to enumeration/resource/google/google_project_iam_binding.go diff --git a/enumeration/resource/google/google_project_iam_member.go b/enumeration/resource/google/google_project_iam_member.go new file mode 100644 index 00000000..1d29be10 --- /dev/null +++ b/enumeration/resource/google/google_project_iam_member.go @@ -0,0 +1,17 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleProjectIamMemberResourceType = "google_project_iam_member" + +func initGoogleProjectIAMMemberMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleProjectIamMemberResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "project": *res.Attrs.GetString("project"), + "role": *res.Attrs.GetString("role"), + "member": *res.Attrs.GetString("member"), + } + }) + resourceSchemaRepository.SetFlags(GoogleProjectIamMemberResourceType, resource.FlagDeepMode) + +} diff --git a/pkg/resource/google/google_project_iam_policy.go b/enumeration/resource/google/google_project_iam_policy.go similarity index 100% rename from pkg/resource/google/google_project_iam_policy.go rename to enumeration/resource/google/google_project_iam_policy.go diff --git a/pkg/resource/google/google_sql_database_instance.go b/enumeration/resource/google/google_sql_database_instance.go similarity index 100% rename from pkg/resource/google/google_sql_database_instance.go rename to enumeration/resource/google/google_sql_database_instance.go diff --git a/enumeration/resource/google/google_storage_bucket.go b/enumeration/resource/google/google_storage_bucket.go new file mode 100644 index 00000000..42ecc5f5 --- /dev/null +++ b/enumeration/resource/google/google_storage_bucket.go @@ -0,0 +1,14 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleStorageBucketResourceType = "google_storage_bucket" + +func initGoogleStorageBucketMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleStorageBucketResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "name": res.ResourceId(), + } + }) + resourceSchemaRepository.SetFlags(GoogleStorageBucketResourceType, resource.FlagDeepMode) +} diff --git a/pkg/resource/google/google_storage_bucket_iam_binding.go b/enumeration/resource/google/google_storage_bucket_iam_binding.go similarity index 100% rename from pkg/resource/google/google_storage_bucket_iam_binding.go rename to enumeration/resource/google/google_storage_bucket_iam_binding.go diff --git a/enumeration/resource/google/google_storage_bucket_iam_member.go b/enumeration/resource/google/google_storage_bucket_iam_member.go new file mode 100644 index 00000000..3be21401 --- /dev/null +++ b/enumeration/resource/google/google_storage_bucket_iam_member.go @@ -0,0 +1,25 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +const GoogleStorageBucketIamMemberResourceType = "google_storage_bucket_iam_member" + +func initGoogleStorageBucketIamBMemberMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleStorageBucketIamMemberResourceType, func(res *resource.Resource) map[string]string { + return map[string]string{ + "bucket": *res.Attrs.GetString("bucket"), + "role": *res.Attrs.GetString("role"), + "member": *res.Attrs.GetString("member"), + } + }) + resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleStorageBucketIamMemberResourceType, func(res *resource.Resource) map[string]string { + attrs := map[string]string{ + "bucket": *res.Attrs.GetString("bucket"), + "role": *res.Attrs.GetString("role"), + "member": *res.Attrs.GetString("member"), + } + return attrs + }) + resourceSchemaRepository.SetFlags(GoogleStorageBucketIamMemberResourceType, resource.FlagDeepMode) + +} diff --git a/pkg/resource/google/google_storage_bucket_iam_policy.go b/enumeration/resource/google/google_storage_bucket_iam_policy.go similarity index 100% rename from pkg/resource/google/google_storage_bucket_iam_policy.go rename to enumeration/resource/google/google_storage_bucket_iam_policy.go diff --git a/enumeration/resource/google/metadata_test.go b/enumeration/resource/google/metadata_test.go new file mode 100644 index 00000000..7ff82e99 --- /dev/null +++ b/enumeration/resource/google/metadata_test.go @@ -0,0 +1,58 @@ +package google + +import ( + "testing" + + tf "github.com/snyk/driftctl/enumeration/terraform" + + "github.com/snyk/driftctl/enumeration/resource" + testresource "github.com/snyk/driftctl/test/resource" + "github.com/stretchr/testify/assert" +) + +func TestGoogle_Metadata_Flags(t *testing.T) { + testcases := map[string][]resource.Flags{ + GoogleBigqueryDatasetResourceType: {}, + GoogleComputeFirewallResourceType: {resource.FlagDeepMode}, + GoogleComputeInstanceResourceType: {}, + GoogleComputeInstanceGroupResourceType: {resource.FlagDeepMode}, + GoogleComputeNetworkResourceType: {resource.FlagDeepMode}, + GoogleComputeRouterResourceType: {}, + GoogleDNSManagedZoneResourceType: {}, + GoogleProjectIamBindingResourceType: {}, + GoogleProjectIamMemberResourceType: {resource.FlagDeepMode}, + GoogleProjectIamPolicyResourceType: {}, + GoogleStorageBucketResourceType: {resource.FlagDeepMode}, + GoogleStorageBucketIamBindingResourceType: {}, + GoogleStorageBucketIamMemberResourceType: {resource.FlagDeepMode}, + GoogleStorageBucketIamPolicyResourceType: {}, + GoogleBigqueryTableResourceType: {}, + GoogleComputeDiskResourceType: {}, + GoogleBigTableInstanceResourceType: {}, + GoogleComputeGlobalAddressResourceType: {}, + GoogleCloudRunServiceResourceType: {}, + GoogleComputeNodeGroupResourceType: {}, + GoogleComputeForwardingRuleResourceType: {}, + GoogleComputeInstanceGroupManagerResourceType: {}, + GoogleComputeGlobalForwardingRuleResourceType: {}, + } + + schemaRepository := testresource.InitFakeSchemaRepository(tf.GOOGLE, "3.78.0") + InitResourcesMetadata(schemaRepository) + + for ty, flags := range testcases { + t.Run(ty, func(tt *testing.T) { + sch, exist := schemaRepository.GetSchema(ty) + assert.True(tt, exist) + + if len(flags) == 0 { + assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") + return + } + + for _, flag := range flags { + assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) + } + }) + } +} diff --git a/enumeration/resource/google/metadatas.go b/enumeration/resource/google/metadatas.go new file mode 100644 index 00000000..9cb8c876 --- /dev/null +++ b/enumeration/resource/google/metadatas.go @@ -0,0 +1,22 @@ +package google + +import "github.com/snyk/driftctl/enumeration/resource" + +func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { + initGoogleStorageBucketMetadata(resourceSchemaRepository) + initGoogleComputeFirewallMetadata(resourceSchemaRepository) + initGoogleComputeRouterMetadata(resourceSchemaRepository) + initGoogleComputeNetworkMetadata(resourceSchemaRepository) + initGoogleStorageBucketIamBMemberMetadata(resourceSchemaRepository) + initGoogleComputeInstanceGroupMetadata(resourceSchemaRepository) + initGoogleBigqueryDatasetMetadata(resourceSchemaRepository) + initGoogleBigqueryTableMetadata(resourceSchemaRepository) + initGoogleProjectIAMMemberMetadata(resourceSchemaRepository) + initGoogleComputeAddressMetadata(resourceSchemaRepository) + initGoogleComputeGlobalAddressMetadata(resourceSchemaRepository) + initGoogleComputeSubnetworkMetadata(resourceSchemaRepository) + initGoogleComputeDiskMetadata(resourceSchemaRepository) + initGoogleComputeImageMetadata(resourceSchemaRepository) + initGoogleComputeHealthCheckMetadata(resourceSchemaRepository) + initComputeInstanceGroupManagerMetadata(resourceSchemaRepository) +} diff --git a/enumeration/resource/mock_Supplier.go b/enumeration/resource/mock_Supplier.go new file mode 100644 index 00000000..981d4f48 --- /dev/null +++ b/enumeration/resource/mock_Supplier.go @@ -0,0 +1,33 @@ +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. + +package resource + +import mock "github.com/stretchr/testify/mock" + +// MockSupplier is an autogenerated mock type for the Supplier type +type MockSupplier struct { + mock.Mock +} + +// Resources provides a mock function with given fields: +func (_m *MockSupplier) Resources() ([]*Resource, error) { + ret := _m.Called() + + var r0 []*Resource + if rf, ok := ret.Get(0).(func() []*Resource); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*Resource) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/resource/resource.go b/enumeration/resource/resource.go similarity index 100% rename from pkg/resource/resource.go rename to enumeration/resource/resource.go diff --git a/pkg/resource/resource_test.go b/enumeration/resource/resource_test.go similarity index 100% rename from pkg/resource/resource_test.go rename to enumeration/resource/resource_test.go diff --git a/pkg/resource/resource_types.go b/enumeration/resource/resource_types.go similarity index 100% rename from pkg/resource/resource_types.go rename to enumeration/resource/resource_types.go diff --git a/pkg/resource/schemas.go b/enumeration/resource/schemas.go similarity index 100% rename from pkg/resource/schemas.go rename to enumeration/resource/schemas.go diff --git a/enumeration/resource/supplier.go b/enumeration/resource/supplier.go new file mode 100644 index 00000000..054b64d0 --- /dev/null +++ b/enumeration/resource/supplier.go @@ -0,0 +1,11 @@ +package resource + +// Supplier supply the list of resource.Resource, it's the main interface to retrieve remote resources +type Supplier interface { + Resources() ([]*Resource, error) +} + +type StoppableSupplier interface { + Supplier + Stop() +} diff --git a/pkg/terraform/error/provider_download_error.go b/enumeration/terraform/error/provider_download_error.go similarity index 100% rename from pkg/terraform/error/provider_download_error.go rename to enumeration/terraform/error/provider_download_error.go diff --git a/pkg/terraform/lock/lockfile.go b/enumeration/terraform/lock/lockfile.go similarity index 100% rename from pkg/terraform/lock/lockfile.go rename to enumeration/terraform/lock/lockfile.go diff --git a/pkg/terraform/lock/lockfile_test.go b/enumeration/terraform/lock/lockfile_test.go similarity index 100% rename from pkg/terraform/lock/lockfile_test.go rename to enumeration/terraform/lock/lockfile_test.go diff --git a/pkg/terraform/lock/testdata/lockfile_empty.hcl b/enumeration/terraform/lock/testdata/lockfile_empty.hcl similarity index 100% rename from pkg/terraform/lock/testdata/lockfile_empty.hcl rename to enumeration/terraform/lock/testdata/lockfile_empty.hcl diff --git a/pkg/terraform/lock/testdata/lockfile_invalid.hcl b/enumeration/terraform/lock/testdata/lockfile_invalid.hcl similarity index 100% rename from pkg/terraform/lock/testdata/lockfile_invalid.hcl rename to enumeration/terraform/lock/testdata/lockfile_invalid.hcl diff --git a/pkg/terraform/lock/testdata/lockfile_invalid_type-1.hcl b/enumeration/terraform/lock/testdata/lockfile_invalid_type-1.hcl similarity index 100% rename from pkg/terraform/lock/testdata/lockfile_invalid_type-1.hcl rename to enumeration/terraform/lock/testdata/lockfile_invalid_type-1.hcl diff --git a/pkg/terraform/lock/testdata/lockfile_invalid_type-3.hcl b/enumeration/terraform/lock/testdata/lockfile_invalid_type-3.hcl similarity index 100% rename from pkg/terraform/lock/testdata/lockfile_invalid_type-3.hcl rename to enumeration/terraform/lock/testdata/lockfile_invalid_type-3.hcl diff --git a/pkg/terraform/lock/testdata/lockfile_valid.hcl b/enumeration/terraform/lock/testdata/lockfile_valid.hcl similarity index 100% rename from pkg/terraform/lock/testdata/lockfile_valid.hcl rename to enumeration/terraform/lock/testdata/lockfile_valid.hcl diff --git a/enumeration/terraform/mock_ResourceFactory.go b/enumeration/terraform/mock_ResourceFactory.go new file mode 100644 index 00000000..fb58c47a --- /dev/null +++ b/enumeration/terraform/mock_ResourceFactory.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.3.0. DO NOT EDIT. + +package terraform + +import ( + "github.com/snyk/driftctl/enumeration/resource" + mock "github.com/stretchr/testify/mock" + cty "github.com/zclconf/go-cty/cty" +) + +// MockResourceFactory is an autogenerated mock type for the ResourceFactory type +type MockResourceFactory struct { + mock.Mock +} + +// CreateAbstractResource provides a mock function with given fields: ty, id, data +func (_m *MockResourceFactory) CreateAbstractResource(ty string, id string, data map[string]interface{}) *resource.Resource { + ret := _m.Called(ty, id, data) + + var r0 *resource.Resource + if rf, ok := ret.Get(0).(func(string, string, map[string]interface{}) *resource.Resource); ok { + r0 = rf(ty, id, data) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*resource.Resource) + } + } + + return r0 +} + +// CreateResource provides a mock function with given fields: data, ty +func (_m *MockResourceFactory) CreateResource(data interface{}, ty string) (*cty.Value, error) { + ret := _m.Called(data, ty) + + var r0 *cty.Value + if rf, ok := ret.Get(0).(func(interface{}, string) *cty.Value); ok { + r0 = rf(data, ty) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*cty.Value) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(interface{}, string) error); ok { + r1 = rf(data, ty) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/enumeration/terraform/parallel_resource_reader.go b/enumeration/terraform/parallel_resource_reader.go new file mode 100644 index 00000000..ab0f2286 --- /dev/null +++ b/enumeration/terraform/parallel_resource_reader.go @@ -0,0 +1,42 @@ +package terraform + +import ( + "github.com/snyk/driftctl/enumeration/parallel" + "github.com/zclconf/go-cty/cty" +) + +type ParallelResourceReader struct { + runner *parallel.ParallelRunner +} + +func NewParallelResourceReader(runner *parallel.ParallelRunner) *ParallelResourceReader { + return &ParallelResourceReader{ + runner: runner, + } +} + +func (p *ParallelResourceReader) Wait() ([]cty.Value, error) { + results := make([]cty.Value, 0) +Loop: + for { + select { + case res, ok := <-p.runner.Read(): + if !ok { + break Loop + } + ctyVal := res.(cty.Value) + if !ctyVal.IsNull() { + results = append(results, ctyVal) + } + case <-p.runner.DoneChan(): + break Loop + } + } + return results, p.runner.Err() +} + +func (p *ParallelResourceReader) Run(runnable func() (cty.Value, error)) { + p.runner.Run(func() (interface{}, error) { + return runnable() + }) +} diff --git a/enumeration/terraform/parallel_resource_reader_test.go b/enumeration/terraform/parallel_resource_reader_test.go new file mode 100644 index 00000000..4dd89175 --- /dev/null +++ b/enumeration/terraform/parallel_resource_reader_test.go @@ -0,0 +1,79 @@ +package terraform + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/snyk/driftctl/enumeration/parallel" + + "github.com/r3labs/diff/v2" + "github.com/stretchr/testify/assert" + + "github.com/zclconf/go-cty/cty" +) + +func TestParallelResourceReader_Wait(t *testing.T) { + assert := assert.New(t) + tests := []struct { + name string + execs []func() (cty.Value, error) + want []cty.Value + wantErr bool + }{ + { + name: "Working // read resource", + execs: []func() (cty.Value, error){ + func() (cty.Value, error) { + return cty.BoolVal(true), nil + }, + func() (cty.Value, error) { + return cty.StringVal("test"), nil + }, + }, + want: []cty.Value{cty.BoolVal(true), cty.StringVal("test")}, + wantErr: false, + }, + + { + name: "failing // read resource", + execs: []func() (cty.Value, error){ + func() (cty.Value, error) { + return cty.BoolVal(true), nil + }, + func() (cty.Value, error) { + return cty.NilVal, errors.New("error") + }, + func() (cty.Value, error) { + return cty.StringVal("test"), nil + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewParallelResourceReader(parallel.NewParallelRunner(context.TODO(), 10)) + + for _, fun := range tt.execs { + p.Run(fun) + } + + got, err := p.Wait() + assert.Equal(tt.wantErr, err != nil) + if tt.want != nil { + changelog, err := diff.Diff(got, tt.want) + if err != nil { + panic(err) + } + if len(changelog) > 0 { + for _, change := range changelog { + t.Errorf("%s got = %v, want %v", strings.Join(change.Path, "."), change.From, change.To) + } + } + } + }) + } +} diff --git a/pkg/terraform/plugin_client.go b/enumeration/terraform/plugin_client.go similarity index 100% rename from pkg/terraform/plugin_client.go rename to enumeration/terraform/plugin_client.go diff --git a/pkg/terraform/provider_config.go b/enumeration/terraform/provider_config.go similarity index 100% rename from pkg/terraform/provider_config.go rename to enumeration/terraform/provider_config.go diff --git a/pkg/terraform/provider_config_test.go b/enumeration/terraform/provider_config_test.go similarity index 100% rename from pkg/terraform/provider_config_test.go rename to enumeration/terraform/provider_config_test.go diff --git a/enumeration/terraform/provider_downloader.go b/enumeration/terraform/provider_downloader.go new file mode 100644 index 00000000..dafdec64 --- /dev/null +++ b/enumeration/terraform/provider_downloader.go @@ -0,0 +1,82 @@ +package terraform + +import ( + "context" + "io/ioutil" + "net/http" + "os" + + tferror "github.com/snyk/driftctl/enumeration/terraform/error" + + "github.com/hashicorp/go-getter" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type ProviderDownloaderInterface interface { + Download(url, path string) error +} + +type ProviderDownloader struct { + httpclient *http.Client + unzip getter.ZipDecompressor + context context.Context +} + +func NewProviderDownloader() *ProviderDownloader { + return &ProviderDownloader{ + httpclient: http.DefaultClient, + unzip: getter.ZipDecompressor{}, + context: context.Background(), + } +} + +func (p *ProviderDownloader) Download(url, path string) error { + logrus.WithFields(logrus.Fields{ + "url": url, + "path": path, + }).Debug("Downloading provider") + + req, err := http.NewRequestWithContext(p.context, "GET", url, nil) + if err != nil { + return err + } + resp, err := p.httpclient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusForbidden { + return tferror.ProviderNotFoundError{} + } + if resp.StatusCode != http.StatusOK { + return errors.Errorf("unsuccessful request to %s: %s", url, resp.Status) + } + f, err := ioutil.TempFile("", "terraform-provider") + if err != nil { + return errors.Errorf("failed to open temporary file to download from %s", url) + } + defer f.Close() + defer os.Remove(f.Name()) + n, err := getter.Copy(p.context, f, resp.Body) + if err == nil && n < resp.ContentLength { + err = errors.Errorf( + "incorrect response size: expected %d bytes, but got %d bytes", + resp.ContentLength, + n, + ) + } + if err != nil { + return err + } + logrus.WithFields(logrus.Fields{ + "src": f.Name(), + "dst": path, + }).Debug("Decompressing archive") + err = p.unzip.Decompress(path, f.Name(), true, 0) + if err != nil { + return err + } + return nil +} diff --git a/enumeration/terraform/provider_downloader_test.go b/enumeration/terraform/provider_downloader_test.go new file mode 100644 index 00000000..6c6503d2 --- /dev/null +++ b/enumeration/terraform/provider_downloader_test.go @@ -0,0 +1,109 @@ +package terraform + +import ( + "fmt" + terraformError "github.com/snyk/driftctl/enumeration/terraform/error" + "io/ioutil" + "net/http" + "path" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" + + "github.com/jarcoal/httpmock" +) + +func TestProviderDownloader_Download(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + downloader := NewProviderDownloader() + url := "https://example.com/terraform-provider-aws_3.19.0_linux_amd64.zip" + + cases := []struct { + name string + httpStatus *int + testFile *string + responder httpmock.Responder + assert func(assert *assert.Assertions, tmpDir string, err error) + }{ + { + name: "TestBadResponse(404)", + responder: httpmock.NewBytesResponder(http.StatusNotFound, []byte{}), + assert: func(assert *assert.Assertions, tmpDir string, err error) { + assert.Equal( + fmt.Sprintf("unsuccessful request to %s: 404", url), + err.Error(), + ) + }, + }, + { + name: "TestProviderNotFound(403)", + responder: httpmock.NewBytesResponder(http.StatusForbidden, []byte{}), + assert: func(assert *assert.Assertions, tmpDir string, err error) { + assert.IsType( + terraformError.ProviderNotFoundError{}, + err, + ) + }, + }, + { + name: "TestHttpError", + responder: httpmock.NewErrorResponder(fmt.Errorf("test error")), + assert: func(assert *assert.Assertions, tmpDir string, err error) { + assert.Contains(err.Error(), "test error") + }, + }, + { + name: "TestInvalidZip", + testFile: aws.String("invalid.zip"), + assert: func(assert *assert.Assertions, tmpDir string, err error) { + assert.NotNil(err) + infos, err := ioutil.ReadDir(tmpDir) + assert.Nil(err) + assert.Len(infos, 0) + }, + }, + { + name: "TestValidZip", + testFile: aws.String("terraform-provider-aws_3.5.0_linux_amd64.zip"), + assert: func(assert *assert.Assertions, tmpDir string, err error) { + assert.Nil(err) + file, err := ioutil.ReadFile(path.Join(tmpDir, "terraform-provider-aws_v3.5.0_x5")) + assert.Nil(err) + assert.Equal([]byte{0x74, 0x65, 0x73, 0x74, 0xa}, file) + }, + }, + } + + for _, c := range cases { + + t.Run(c.name, func(tt *testing.T) { + tmpDir := tt.TempDir() + + httpmock.Reset() + assert := assert.New(tt) + + if c.httpStatus == nil { + c.httpStatus = aws.Int(http.StatusOK) + } + + if c.responder != nil { + httpmock.RegisterResponder("GET", url, c.responder) + } else { + if c.testFile != nil { + body, err := ioutil.ReadFile("./testdata/" + *c.testFile) + if err != nil { + tt.Error(err) + } + httpmock.RegisterResponder("GET", url, httpmock.NewBytesResponder(*c.httpStatus, body)) + } + } + + err := downloader.Download(url, tmpDir) + + c.assert(assert, tmpDir, err) + }) + + } +} diff --git a/pkg/terraform/provider_factory.go b/enumeration/terraform/provider_factory.go similarity index 100% rename from pkg/terraform/provider_factory.go rename to enumeration/terraform/provider_factory.go diff --git a/enumeration/terraform/provider_installer.go b/enumeration/terraform/provider_installer.go new file mode 100644 index 00000000..f40b8c16 --- /dev/null +++ b/enumeration/terraform/provider_installer.go @@ -0,0 +1,95 @@ +package terraform + +import ( + "fmt" + "io/fs" + "os" + "path" + "path/filepath" + "runtime" + "strings" + + error2 "github.com/snyk/driftctl/enumeration/terraform/error" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type HomeDirInterface interface { + Dir() (string, error) +} + +type ProviderInstaller struct { + downloader ProviderDownloaderInterface + config ProviderConfig + homeDir string +} + +func NewProviderInstaller(config ProviderConfig) (*ProviderInstaller, error) { + return &ProviderInstaller{ + NewProviderDownloader(), + config, + config.ConfigDir, + }, nil +} + +func (p *ProviderInstaller) Install() (string, error) { + providerDir := p.getProviderDirectory() + providerPath := p.getBinaryPath() + + info, err := os.Stat(providerPath) + + if err != nil && os.IsNotExist(err) { + logrus.WithFields(logrus.Fields{ + "path": providerPath, + }).Debug("provider not found, downloading ...") + err := p.downloader.Download( + p.config.GetDownloadUrl(), + providerDir, + ) + if err != nil { + if notFoundErr, ok := err.(error2.ProviderNotFoundError); ok { + notFoundErr.Version = p.config.Version + return "", notFoundErr + } + return "", err + } + logrus.Debug("Download successful") + } + + if info != nil && info.IsDir() { + return "", errors.Errorf( + "found directory instead of provider binary in %s", + providerPath, + ) + } + + if info != nil { + logrus.WithFields(logrus.Fields{ + "path": providerPath, + }).Debug("Found existing provider") + } + + return p.getBinaryPath(), nil +} + +func (p ProviderInstaller) getProviderDirectory() string { + return path.Join(p.homeDir, fmt.Sprintf(".driftctl/plugins/%s_%s/", runtime.GOOS, runtime.GOARCH)) +} + +// Handle postfixes in binary names +func (p *ProviderInstaller) getBinaryPath() string { + providerDir := p.getProviderDirectory() + binaryName := p.config.GetBinaryName() + _, err := os.Stat(path.Join(providerDir, binaryName)) + if err != nil && os.IsNotExist(err) { + _ = filepath.WalkDir(providerDir, func(filePath string, d fs.DirEntry, err error) error { + if d != nil && strings.HasPrefix(d.Name(), p.config.GetBinaryName()) { + binaryName = d.Name() + } + return nil + }) + } + + return path.Join(providerDir, binaryName) +} diff --git a/enumeration/terraform/provider_installer_test.go b/enumeration/terraform/provider_installer_test.go new file mode 100644 index 00000000..43814e84 --- /dev/null +++ b/enumeration/terraform/provider_installer_test.go @@ -0,0 +1,206 @@ +package terraform + +import ( + "fmt" + terraformError "github.com/snyk/driftctl/enumeration/terraform/error" + "os" + "path" + "runtime" + "testing" + + "github.com/snyk/driftctl/mocks" + "github.com/stretchr/testify/mock" + + "github.com/stretchr/testify/assert" +) + +func TestProviderInstallerInstallDoesNotExist(t *testing.T) { + + assert := assert.New(t) + fakeTmpHome := t.TempDir() + + expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) + + config := ProviderConfig{ + Key: "aws", + Version: "3.19.0", + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil) + + installer := ProviderInstaller{ + downloader: &mockDownloader, + config: config, + homeDir: fakeTmpHome, + } + + providerPath, err := installer.Install() + mockDownloader.AssertExpectations(t) + + assert.Nil(err) + assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) + +} + +func TestProviderInstallerInstallAlreadyExist(t *testing.T) { + + assert := assert.New(t) + fakeTmpHome := t.TempDir() + expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) + err := os.MkdirAll(path.Join(fakeTmpHome, expectedSubFolder), 0755) + if err != nil { + t.Error(err) + } + + config := ProviderConfig{ + Key: "aws", + Version: "3.19.0", + } + + _, err = os.Create(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName())) + if err != nil { + t.Error(err) + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + + installer := ProviderInstaller{ + downloader: &mockDownloader, + config: config, + homeDir: fakeTmpHome, + } + + providerPath, err := installer.Install() + mockDownloader.AssertExpectations(t) + + assert.Nil(err) + assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) + +} + +func TestProviderInstallerInstallAlreadyExistButIsDirectory(t *testing.T) { + + assert := assert.New(t) + fakeTmpHome := t.TempDir() + expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) + + config := ProviderConfig{ + Key: "aws", + Version: "3.19.0", + } + + invalidDirPath := path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()) + err := os.MkdirAll(invalidDirPath, 0755) + if err != nil { + t.Error(err) + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + + installer := ProviderInstaller{ + downloader: &mockDownloader, + config: config, + homeDir: fakeTmpHome, + } + + providerPath, err := installer.Install() + mockDownloader.AssertExpectations(t) + + assert.Empty(providerPath) + assert.NotNil(err) + assert.Equal( + fmt.Sprintf( + "found directory instead of provider binary in %s", + invalidDirPath, + ), + err.Error(), + ) + +} + +// Ensure that if a provider exists with a postfix (_x5) we properly detect it +func TestProviderInstallerInstallPostfixIsHandler(t *testing.T) { + + assert := assert.New(t) + fakeTmpHome := t.TempDir() + expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) + err := os.MkdirAll(path.Join(fakeTmpHome, expectedSubFolder), 0755) + if err != nil { + t.Error(err) + } + + config := ProviderConfig{ + Key: "aws", + Version: "3.19.0", + } + + _, err = os.Create(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()+"_x5")) + if err != nil { + t.Fatal(err) + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + + installer := ProviderInstaller{ + downloader: &mockDownloader, + config: config, + homeDir: fakeTmpHome, + } + + providerPath, err := installer.Install() + mockDownloader.AssertExpectations(t) + + assert.Nil(err) + assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()+"_x5"), providerPath) + +} + +func TestProviderInstallerVersionDoesNotExist(t *testing.T) { + + assert := assert.New(t) + + config := ProviderConfig{ + Key: "aws", + Version: "666.666.666", + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + mockDownloader.On("Download", mock.Anything, mock.Anything).Return(terraformError.ProviderNotFoundError{}) + + installer := ProviderInstaller{ + downloader: &mockDownloader, + config: config, + } + + _, err := installer.Install() + + assert.Equal("Provider version 666.666.666 does not exist", err.Error()) +} + +func TestProviderInstallerWithConfigDirectory(t *testing.T) { + + assert := assert.New(t) + fakeTmpHome := t.TempDir() + + expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) + + config := ProviderConfig{ + Key: "aws", + Version: "3.19.0", + ConfigDir: fakeTmpHome, + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil) + + installer, _ := NewProviderInstaller(config) + installer.downloader = &mockDownloader + + providerPath, err := installer.Install() + mockDownloader.AssertExpectations(t) + + assert.Nil(err) + assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) + +} diff --git a/pkg/terraform/providers.go b/enumeration/terraform/providers.go similarity index 100% rename from pkg/terraform/providers.go rename to enumeration/terraform/providers.go diff --git a/enumeration/terraform/resource_factory.go b/enumeration/terraform/resource_factory.go new file mode 100644 index 00000000..e80d3097 --- /dev/null +++ b/enumeration/terraform/resource_factory.go @@ -0,0 +1,35 @@ +package terraform + +import ( + "github.com/snyk/driftctl/enumeration/resource" +) + +type TerraformResourceFactory struct { + resourceSchemaRepository resource.SchemaRepositoryInterface +} + +func NewTerraformResourceFactory(resourceSchemaRepository resource.SchemaRepositoryInterface) *TerraformResourceFactory { + return &TerraformResourceFactory{ + resourceSchemaRepository: resourceSchemaRepository, + } +} + +func (r *TerraformResourceFactory) CreateAbstractResource(ty, id string, data map[string]interface{}) *resource.Resource { + attributes := resource.Attributes(data) + attributes.SanitizeDefaults() + + schema, _ := r.resourceSchemaRepository.GetSchema(ty) + res := resource.Resource{ + Id: id, + Type: ty, + Attrs: &attributes, + Sch: schema, + } + + schema, exist := r.resourceSchemaRepository.(*resource.SchemaRepository).GetSchema(ty) + if exist && schema.NormalizeFunc != nil { + schema.NormalizeFunc(&res) + } + + return &res +} diff --git a/enumeration/terraform/resource_reader.go b/enumeration/terraform/resource_reader.go new file mode 100644 index 00000000..1f7a6e36 --- /dev/null +++ b/enumeration/terraform/resource_reader.go @@ -0,0 +1,17 @@ +package terraform + +import ( + "github.com/snyk/driftctl/enumeration/resource" + + "github.com/zclconf/go-cty/cty" +) + +type ResourceReader interface { + ReadResource(args ReadResourceArgs) (*cty.Value, error) +} + +type ReadResourceArgs struct { + Ty resource.ResourceType + ID string + Attributes map[string]string +} diff --git a/pkg/terraform/schema_supplier.go b/enumeration/terraform/schema_supplier.go similarity index 100% rename from pkg/terraform/schema_supplier.go rename to enumeration/terraform/schema_supplier.go diff --git a/pkg/terraform/terraform_provider.go b/enumeration/terraform/terraform_provider.go similarity index 100% rename from pkg/terraform/terraform_provider.go rename to enumeration/terraform/terraform_provider.go diff --git a/pkg/terraform/testdata/invalid.zip b/enumeration/terraform/testdata/invalid.zip similarity index 100% rename from pkg/terraform/testdata/invalid.zip rename to enumeration/terraform/testdata/invalid.zip diff --git a/pkg/terraform/testdata/terraform-provider-aws_3.5.0_linux_amd64.zip b/enumeration/terraform/testdata/terraform-provider-aws_3.5.0_linux_amd64.zip similarity index 100% rename from pkg/terraform/testdata/terraform-provider-aws_3.5.0_linux_amd64.zip rename to enumeration/terraform/testdata/terraform-provider-aws_3.5.0_linux_amd64.zip diff --git a/mocks/AlerterInterface.go b/mocks/AlerterInterface.go index bfd52a2e..4903a713 100644 --- a/mocks/AlerterInterface.go +++ b/mocks/AlerterInterface.go @@ -3,7 +3,7 @@ package mocks import ( - alerter "github.com/snyk/driftctl/pkg/alerter" + "github.com/snyk/driftctl/enumeration/alerter" mock "github.com/stretchr/testify/mock" ) diff --git a/pkg/alerter/alerter.go b/pkg/alerter/alerter.go deleted file mode 100644 index a6b26880..00000000 --- a/pkg/alerter/alerter.go +++ /dev/null @@ -1,75 +0,0 @@ -package alerter - -import ( - "fmt" - - "github.com/snyk/driftctl/pkg/resource" -) - -type AlerterInterface interface { - SendAlert(key string, alert Alert) -} - -type Alerter struct { - alerts Alerts - alertsCh chan Alerts - doneCh chan bool -} - -func NewAlerter() *Alerter { - var alerter = &Alerter{ - alerts: make(Alerts), - alertsCh: make(chan Alerts), - doneCh: make(chan bool), - } - - go alerter.run() - - return alerter -} - -func (a *Alerter) run() { - defer func() { a.doneCh <- true }() - for alert := range a.alertsCh { - for k, v := range alert { - if val, ok := a.alerts[k]; ok { - a.alerts[k] = append(val, v...) - } else { - a.alerts[k] = v - } - } - } -} - -func (a *Alerter) SetAlerts(alerts Alerts) { - a.alerts = alerts -} - -func (a *Alerter) Retrieve() Alerts { - close(a.alertsCh) - <-a.doneCh - return a.alerts -} - -func (a *Alerter) SendAlert(key string, alert Alert) { - a.alertsCh <- Alerts{ - key: []Alert{alert}, - } -} - -func (a *Alerter) IsResourceIgnored(res *resource.Resource) bool { - alert, alertExists := a.alerts[fmt.Sprintf("%s.%s", res.ResourceType(), res.ResourceId())] - wildcardAlert, wildcardAlertExists := a.alerts[res.ResourceType()] - shouldIgnoreAlert := a.shouldBeIgnored(alert) - shouldIgnoreWildcardAlert := a.shouldBeIgnored(wildcardAlert) - return (alertExists && shouldIgnoreAlert) || (wildcardAlertExists && shouldIgnoreWildcardAlert) -} - -func (a *Alerter) shouldBeIgnored(alert []Alert) bool { - for _, a := range alert { - if a.ShouldIgnoreResource() { - return true - } - } - return false -} diff --git a/pkg/alerter/alerter_test.go b/pkg/alerter/alerter_test.go deleted file mode 100644 index f195760b..00000000 --- a/pkg/alerter/alerter_test.go +++ /dev/null @@ -1,161 +0,0 @@ -package alerter - -import ( - "reflect" - "testing" - - "github.com/snyk/driftctl/pkg/resource" -) - -func TestAlerter_Alert(t *testing.T) { - cases := []struct { - name string - alerts Alerts - expected Alerts - }{ - { - name: "TestNoAlerts", - alerts: nil, - expected: Alerts{}, - }, - { - name: "TestWithSingleAlert", - alerts: Alerts{ - "fakeres.foobar": []Alert{ - &FakeAlert{"This is an alert", false}, - }, - }, - expected: Alerts{ - "fakeres.foobar": []Alert{ - &FakeAlert{"This is an alert", false}, - }, - }, - }, - { - name: "TestWithMultipleAlerts", - alerts: Alerts{ - "fakeres.foobar": []Alert{ - &FakeAlert{"This is an alert", false}, - &FakeAlert{"This is a second alert", true}, - }, - "fakeres.barfoo": []Alert{ - &FakeAlert{"This is a third alert", true}, - }, - }, - expected: Alerts{ - "fakeres.foobar": []Alert{ - &FakeAlert{"This is an alert", false}, - &FakeAlert{"This is a second alert", true}, - }, - "fakeres.barfoo": []Alert{ - &FakeAlert{"This is a third alert", true}, - }, - }, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - alerter := NewAlerter() - - for k, v := range c.alerts { - for _, a := range v { - alerter.SendAlert(k, a) - } - } - - if eq := reflect.DeepEqual(alerter.Retrieve(), c.expected); !eq { - t.Errorf("Got %+v, expected %+v", alerter.Retrieve(), c.expected) - } - }) - } -} - -func TestAlerter_IgnoreResources(t *testing.T) { - cases := []struct { - name string - alerts Alerts - resource *resource.Resource - expected bool - }{ - { - name: "TestNoAlerts", - alerts: Alerts{}, - resource: &resource.Resource{ - Type: "fakeres", - Id: "foobar", - }, - expected: false, - }, - { - name: "TestShouldNotBeIgnoredWithAlerts", - alerts: Alerts{ - "fakeres": { - &FakeAlert{"Should not be ignored", false}, - }, - "fakeres.foobar": { - &FakeAlert{"Should not be ignored", false}, - }, - "fakeres.barfoo": { - &FakeAlert{"Should not be ignored", false}, - }, - "other.resource": { - &FakeAlert{"Should not be ignored", false}, - }, - }, - resource: &resource.Resource{ - Type: "fakeres", - Id: "foobar", - }, - expected: false, - }, - { - name: "TestShouldBeIgnoredWithAlertsOnWildcard", - alerts: Alerts{ - "fakeres": { - &FakeAlert{"Should be ignored", true}, - }, - "other.foobaz": { - &FakeAlert{"Should be ignored", true}, - }, - "other.resource": { - &FakeAlert{"Should not be ignored", false}, - }, - }, - resource: &resource.Resource{ - Type: "fakeres", - Id: "foobar", - }, - expected: true, - }, - { - name: "TestShouldBeIgnoredWithAlertsOnResource", - alerts: Alerts{ - "fakeres": { - &FakeAlert{"Should be ignored", true}, - }, - "other.foobaz": { - &FakeAlert{"Should be ignored", true}, - }, - "other.resource": { - &FakeAlert{"Should not be ignored", false}, - }, - }, - resource: &resource.Resource{ - Type: "other", - Id: "foobaz", - }, - expected: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - alerter := NewAlerter() - alerter.SetAlerts(c.alerts) - if got := alerter.IsResourceIgnored(c.resource); got != c.expected { - t.Errorf("Got %+v, expected %+v", got, c.expected) - } - }) - } -} diff --git a/pkg/analyser/analysis.go b/pkg/analyser/analysis.go index 90b80bee..9de62ef3 100644 --- a/pkg/analyser/analysis.go +++ b/pkg/analyser/analysis.go @@ -7,9 +7,10 @@ import ( "strings" "time" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) type Change struct { diff --git a/pkg/analyser/analyzer.go b/pkg/analyser/analyzer.go index 83e297e6..d26f4576 100644 --- a/pkg/analyser/analyzer.go +++ b/pkg/analyser/analyzer.go @@ -2,11 +2,11 @@ package analyser import ( "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/alerter" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/filter" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) type UnmanagedSecurityGroupRulesAlert struct{} diff --git a/pkg/analyser/analyzer_test.go b/pkg/analyser/analyzer_test.go index 984ef0d9..16bdca89 100644 --- a/pkg/analyser/analyzer_test.go +++ b/pkg/analyser/analyzer_test.go @@ -6,6 +6,9 @@ import ( "testing" "time" + alerter2 "github.com/snyk/driftctl/enumeration/alerter" + aws2 "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/pkg/filter" "github.com/stretchr/testify/mock" @@ -15,9 +18,8 @@ import ( "github.com/snyk/driftctl/test/goldenfile" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/r3labs/diff/v2" ) @@ -32,7 +34,7 @@ func TestAnalyze(t *testing.T) { res *resource.Resource path []string } - alerts alerter.Alerts + alerts alerter2.Alerts expected Analysis hasDrifted bool options *AnalyzerOptions @@ -292,7 +294,7 @@ func TestAnalyze(t *testing.T) { }, }, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "": { NewComputedDiffAlert(), }, @@ -379,7 +381,7 @@ func TestAnalyze(t *testing.T) { }, }, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "": { NewComputedDiffAlert(), }, @@ -574,15 +576,15 @@ func TestAnalyze(t *testing.T) { }, }, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "fakeres": { - &alerter.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, + &alerter2.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, }, "other.foobaz": { - &alerter.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, + &alerter2.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, }, "other.resource": { - &alerter.FakeAlert{Msg: "Should not be ignored"}, + &alerter2.FakeAlert{Msg: "Should not be ignored"}, }, }, expected: Analysis{ @@ -702,15 +704,15 @@ func TestAnalyze(t *testing.T) { }, }, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "fakeres": { - &alerter.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, + &alerter2.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, }, "other.foobaz": { - &alerter.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, + &alerter2.FakeAlert{Msg: "Should be ignored", IgnoreResource: true}, }, "other.resource": { - &alerter.FakeAlert{Msg: "Should not be ignored"}, + &alerter2.FakeAlert{Msg: "Should not be ignored"}, }, }, }, @@ -738,7 +740,7 @@ func TestAnalyze(t *testing.T) { }, }, }, - alerts: alerter.Alerts{}, + alerts: alerter2.Alerts{}, expected: Analysis{ managed: []*resource.Resource{ { @@ -791,7 +793,7 @@ func TestAnalyze(t *testing.T) { }, }, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "": { NewComputedDiffAlert(), }, @@ -850,7 +852,7 @@ func TestAnalyze(t *testing.T) { TotalManaged: 1, TotalUnmanaged: 1, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "": { newUnmanagedSecurityGroupRulesAlert(), }, @@ -924,7 +926,7 @@ func TestAnalyze(t *testing.T) { TotalUnmanaged: 3, TotalDeleted: 3, }, - alerts: alerter.Alerts{}, + alerts: alerter2.Alerts{}, }, hasDrifted: true, }, @@ -1158,13 +1160,14 @@ func TestAnalyze(t *testing.T) { } testFilter.On("IsFieldIgnored", mock.Anything, mock.Anything).Return(false) - al := alerter.NewAlerter() + al := alerter2.NewAlerter() if c.alerts != nil { al.SetAlerts(c.alerts) } repo := testresource.InitFakeSchemaRepository("aws", "3.19.0") aws.InitResourcesMetadata(repo) + aws2.InitResourcesMetadata(repo) options := AnalyzerOptions{Deep: true} if c.options != nil { @@ -1327,9 +1330,9 @@ func TestAnalysis_MarshalJSON(t *testing.T) { }, }, }) - analysis.SetAlerts(alerter.Alerts{ + analysis.SetAlerts(alerter2.Alerts{ "aws_iam_access_key": { - &alerter.FakeAlert{Msg: "This is an alert"}, + &alerter2.FakeAlert{Msg: "This is an alert"}, }, }) analysis.ProviderName = "AWS" @@ -1413,9 +1416,9 @@ func TestAnalysis_UnmarshalJSON(t *testing.T) { }, }, }, - alerts: alerter.Alerts{ + alerts: alerter2.Alerts{ "aws_iam_access_key": { - &alerter.SerializedAlert{ + &alerter2.SerializedAlert{ Msg: "This is an alert", }, }, diff --git a/pkg/cmd/scan.go b/pkg/cmd/scan.go index 28c51035..b5476e99 100644 --- a/pkg/cmd/scan.go +++ b/pkg/cmd/scan.go @@ -16,27 +16,28 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/snyk/driftctl/build" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote" + "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/terraform" + "github.com/snyk/driftctl/enumeration/terraform/lock" "github.com/snyk/driftctl/pkg/analyser" "github.com/snyk/driftctl/pkg/iac/config" "github.com/snyk/driftctl/pkg/iac/terraform/state" "github.com/snyk/driftctl/pkg/memstore" - "github.com/snyk/driftctl/pkg/remote/common" "github.com/snyk/driftctl/pkg/telemetry" "github.com/snyk/driftctl/pkg/terraform/hcl" - "github.com/snyk/driftctl/pkg/terraform/lock" "github.com/spf13/cobra" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg" - "github.com/snyk/driftctl/pkg/alerter" cmderrors "github.com/snyk/driftctl/pkg/cmd/errors" "github.com/snyk/driftctl/pkg/cmd/scan/output" "github.com/snyk/driftctl/pkg/filter" "github.com/snyk/driftctl/pkg/iac/supplier" "github.com/snyk/driftctl/pkg/iac/terraform/state/backend" globaloutput "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" + dctlresource "github.com/snyk/driftctl/pkg/resource" ) func NewScanCmd(opts *pkg.ScanOptions) *cobra.Command { @@ -295,6 +296,10 @@ func scanRun(opts *pkg.ScanOptions) error { return err } + err = dctlresource.InitMetadatas(opts.To, resourceSchemaRepository) + if err != nil { + return err + } // Teardown defer func() { logrus.Trace("Exiting scan cmd") diff --git a/pkg/cmd/scan/output/console.go b/pkg/cmd/scan/output/console.go index 23c2e0cb..2f8a70a8 100644 --- a/pkg/cmd/scan/output/console.go +++ b/pkg/cmd/scan/output/console.go @@ -3,6 +3,7 @@ package output import ( "encoding/json" "fmt" + "github.com/snyk/driftctl/enumeration/remote/alerts" "os" "reflect" "sort" @@ -12,12 +13,11 @@ import ( "github.com/fatih/color" "github.com/mattn/go-isatty" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/remote/alerts" "github.com/yudai/gojsondiff" "github.com/yudai/gojsondiff/formatter" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/analyser" - "github.com/snyk/driftctl/pkg/resource" ) const ConsoleOutputType = "console" diff --git a/pkg/cmd/scan/output/console_test.go b/pkg/cmd/scan/output/console_test.go index 87751b18..abdbc2d4 100644 --- a/pkg/cmd/scan/output/console_test.go +++ b/pkg/cmd/scan/output/console_test.go @@ -8,8 +8,8 @@ import ( "path" "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/test/goldenfile" testresource "github.com/snyk/driftctl/test/resource" "github.com/stretchr/testify/assert" diff --git a/pkg/cmd/scan/output/html.go b/pkg/cmd/scan/output/html.go index f7adb2d0..57186963 100644 --- a/pkg/cmd/scan/output/html.go +++ b/pkg/cmd/scan/output/html.go @@ -5,6 +5,7 @@ import ( "embed" "encoding/base64" "fmt" + "github.com/snyk/driftctl/enumeration/alerter" "html/template" "math" "os" @@ -16,9 +17,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/alerter" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/analyser" - "github.com/snyk/driftctl/pkg/resource" ) const HTMLOutputType = "html" diff --git a/pkg/cmd/scan/output/html_test.go b/pkg/cmd/scan/output/html_test.go index 9490a9a7..7542de9a 100644 --- a/pkg/cmd/scan/output/html_test.go +++ b/pkg/cmd/scan/output/html_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" "github.com/stretchr/testify/assert" "github.com/snyk/driftctl/pkg/analyser" diff --git a/pkg/cmd/scan/output/output_test.go b/pkg/cmd/scan/output/output_test.go index 7ff6c11d..70b2c3fd 100644 --- a/pkg/cmd/scan/output/output_test.go +++ b/pkg/cmd/scan/output/output_test.go @@ -6,16 +6,17 @@ import ( "testing" "time" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/remote/alerts" + "github.com/snyk/driftctl/enumeration/remote/common" + remoteerr "github.com/snyk/driftctl/enumeration/remote/error" + "github.com/pkg/errors" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/alerter" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/analyser" "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" ) func fakeAnalysis(opts analyser.AnalyzerOptions) *analyser.Analysis { diff --git a/pkg/cmd/scan/output/plan.go b/pkg/cmd/scan/output/plan.go index 48ab75dc..ceed70aa 100644 --- a/pkg/cmd/scan/output/plan.go +++ b/pkg/cmd/scan/output/plan.go @@ -5,8 +5,8 @@ import ( "fmt" "os" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/analyser" - "github.com/snyk/driftctl/pkg/resource" ) const FormatVersion = "0.1" diff --git a/pkg/driftctl.go b/pkg/driftctl.go index f9130d4a..953f4cf4 100644 --- a/pkg/driftctl.go +++ b/pkg/driftctl.go @@ -4,19 +4,21 @@ import ( "fmt" "time" + "github.com/snyk/driftctl/enumeration/alerter" + resource2 "github.com/snyk/driftctl/pkg/resource" + "github.com/jmespath/go-jmespath" "github.com/sirupsen/logrus" "github.com/snyk/driftctl/pkg/memstore" globaloutput "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/alerter" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/analyser" "github.com/snyk/driftctl/pkg/cmd/scan/output" "github.com/snyk/driftctl/pkg/filter" "github.com/snyk/driftctl/pkg/iac/config" "github.com/snyk/driftctl/pkg/iac/terraform/state/backend" "github.com/snyk/driftctl/pkg/middlewares" - "github.com/snyk/driftctl/pkg/resource" ) type FmtOptions struct { @@ -45,7 +47,7 @@ type ScanOptions struct { type DriftCTL struct { remoteSupplier resource.Supplier - iacSupplier resource.IaCSupplier + iacSupplier resource2.IaCSupplier alerter alerter.AlerterInterface analyzer *analyser.Analyzer resourceFactory resource.ResourceFactory @@ -57,7 +59,7 @@ type DriftCTL struct { } func NewDriftCTL(remoteSupplier resource.Supplier, - iacSupplier resource.IaCSupplier, + iacSupplier resource2.IaCSupplier, alerter *alerter.Alerter, analyzer *analyser.Analyzer, resFactory resource.ResourceFactory, diff --git a/pkg/driftctl_test.go b/pkg/driftctl_test.go index 1d8a99f1..6a857651 100644 --- a/pkg/driftctl_test.go +++ b/pkg/driftctl_test.go @@ -1,20 +1,22 @@ package pkg_test import ( + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + resource2 "github.com/snyk/driftctl/pkg/resource" + "reflect" "testing" "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" + "github.com/snyk/driftctl/enumeration/resource/github" "github.com/snyk/driftctl/pkg" - "github.com/snyk/driftctl/pkg/alerter" "github.com/snyk/driftctl/pkg/analyser" "github.com/snyk/driftctl/pkg/filter" "github.com/snyk/driftctl/pkg/memstore" "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" "github.com/snyk/driftctl/test" testresource "github.com/snyk/driftctl/test/resource" "github.com/stretchr/testify/assert" @@ -62,7 +64,7 @@ func runTest(t *testing.T, cases TestCases) { res.Sch = schema } - stateSupplier := &resource.MockIaCSupplier{} + stateSupplier := &resource2.MockIaCSupplier{} stateSupplier.On("Resources").Return(c.stateResources, nil) stateSupplier.On("SourceCount").Return(uint(2)) @@ -77,10 +79,10 @@ func runTest(t *testing.T, cases TestCases) { remoteSupplier := &resource.MockSupplier{} remoteSupplier.On("Resources").Return(c.remoteResources, nil) - var resourceFactory resource.ResourceFactory = terraform.NewTerraformResourceFactory(repo) + var resourceFactory resource.ResourceFactory = terraform2.NewTerraformResourceFactory(repo) if c.mocks != nil { - resourceFactory = &terraform.MockResourceFactory{} + resourceFactory = &terraform2.MockResourceFactory{} c.mocks(resourceFactory, repo) } @@ -364,7 +366,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) { { name: "we should ignore default AWS IAM role when strict mode is disabled", mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On( + factory.(*terraform2.MockResourceFactory).On( "CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "role-test-1-policy-test-1", @@ -462,7 +464,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) { { name: "we should not ignore default AWS IAM role when strict mode is enabled", mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On( + factory.(*terraform2.MockResourceFactory).On( "CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "role-test-1-policy-test-1", @@ -560,7 +562,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) { { name: "we should not ignore default AWS IAM role when strict mode is enabled and a filter is specified", mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On( + factory.(*terraform2.MockResourceFactory).On( "CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "role-test-1-policy-test-1", @@ -791,7 +793,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, }, mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On( + factory.(*terraform2.MockResourceFactory).On( "CreateAbstractResource", aws.AwsS3BucketPolicyResourceType, "foo", @@ -888,7 +890,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, Sch: getSchema(repo, "aws_ebs_volume"), } - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_ebs_volume", mock.Anything, mock.MatchedBy(func(input map[string]interface{}) bool { + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_ebs_volume", mock.Anything, mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "id": "vol-018c5ae89895aca4c", "availability_zone": "us-east-1", @@ -907,7 +909,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, Sch: getSchema(repo, "aws_ebs_volume"), } - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_ebs_volume", mock.Anything, mock.MatchedBy(func(input map[string]interface{}) bool { + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_ebs_volume", mock.Anything, mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "id": "vol-02862d9b39045a3a4", "availability_zone": "us-east-1", @@ -993,7 +995,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, }, mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_route", "r-table1080289494", mock.MatchedBy(func(input map[string]interface{}) bool { + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_route", "r-table1080289494", mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "destination_cidr_block": "0.0.0.0/0", "gateway_id": "igw-07b7844a8fd17a638", @@ -1012,7 +1014,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "state": "active", }, }, nil) - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_route", "r-table2750132062", mock.MatchedBy(func(input map[string]interface{}) bool { + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_route", "r-table2750132062", mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "destination_ipv6_cidr_block": "::/0", "gateway_id": "igw-07b7844a8fd17a638", @@ -1071,7 +1073,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, }, mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_sns_topic_policy", "foo", map[string]interface{}{ + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_sns_topic_policy", "foo", map[string]interface{}{ "id": "foo", "arn": "arn", "policy": "{\"policy\":\"bar\"}", @@ -1133,7 +1135,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, }, mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_sqs_queue_policy", "foo", map[string]interface{}{ + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_sqs_queue_policy", "foo", map[string]interface{}{ "id": "foo", "queue_url": "foo", "policy": "{\"policy\":\"bar\"}", @@ -1362,7 +1364,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "prefix_list_ids": []interface{}{}, }, } - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule1.Id, + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule1.Id, mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "id": "sgrule-1707973622", @@ -1394,7 +1396,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "prefix_list_ids": []interface{}{}, }, } - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule2.Id, + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule2.Id, mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "id": "sgrule-2821752134", @@ -1426,7 +1428,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "prefix_list_ids": []interface{}{}, }, } - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule3.Id, + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule3.Id, mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "id": "sgrule-2165103420", @@ -1458,7 +1460,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "prefix_list_ids": []interface{}{}, }, } - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule4.Id, + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", "aws_security_group_rule", rule4.Id, mock.MatchedBy(func(input map[string]interface{}) bool { return matchByAttributes(input, map[string]interface{}{ "id": "sgrule-2582518759", @@ -1554,7 +1556,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { }, }, mocks: func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) { - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "iduser1", map[string]interface{}{ + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "iduser1", map[string]interface{}{ "id": "iduser1", "policy_arn": "policy_arn1", "users": []interface{}{"user1"}, @@ -1571,7 +1573,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "roles": []interface{}{}, }, }, nil) - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "user1-policy_arn1", map[string]interface{}{ + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "user1-policy_arn1", map[string]interface{}{ "policy_arn": "policy_arn1", "users": []interface{}{"user1"}, }).Twice().Return(&resource.Resource{ @@ -1582,7 +1584,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "users": []interface{}{"user1"}, }, }, nil) - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "idrole1", map[string]interface{}{ + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "idrole1", map[string]interface{}{ "id": "idrole1", "policy_arn": "policy_arn1", "users": []interface{}{}, @@ -1599,7 +1601,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) { "roles": []interface{}{"role1"}, }, }, nil) - factory.(*terraform.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "role1-policy_arn1", map[string]interface{}{ + factory.(*terraform2.MockResourceFactory).On("CreateAbstractResource", aws.AwsIamPolicyAttachmentResourceType, "role1-policy_arn1", map[string]interface{}{ "policy_arn": "policy_arn1", "roles": []interface{}{"role1"}, }).Twice().Return(&resource.Resource{ diff --git a/pkg/filter/driftignore.go b/pkg/filter/driftignore.go index ac9033f2..762a1245 100644 --- a/pkg/filter/driftignore.go +++ b/pkg/filter/driftignore.go @@ -8,7 +8,7 @@ import ( "github.com/go-git/go-git/v5/plumbing/format/gitignore" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) const separator = "_-_" diff --git a/pkg/filter/driftignore_test.go b/pkg/filter/driftignore_test.go index 575ea7b2..0695a874 100644 --- a/pkg/filter/driftignore_test.go +++ b/pkg/filter/driftignore_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestDriftIgnore_IsResourceIgnored(t *testing.T) { diff --git a/pkg/filter/filter.go b/pkg/filter/filter.go index 479501b6..5a08a1ae 100644 --- a/pkg/filter/filter.go +++ b/pkg/filter/filter.go @@ -1,6 +1,6 @@ package filter -import "github.com/snyk/driftctl/pkg/resource" +import "github.com/snyk/driftctl/enumeration/resource" type Filter interface { IsTypeIgnored(ty resource.ResourceType) bool diff --git a/pkg/filter/filter_engine.go b/pkg/filter/filter_engine.go index 699dc702..8bcacdea 100644 --- a/pkg/filter/filter_engine.go +++ b/pkg/filter/filter_engine.go @@ -4,7 +4,7 @@ import ( "errors" "github.com/jmespath/go-jmespath" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) type FilterEngine struct { diff --git a/pkg/filter/filter_engine_test.go b/pkg/filter/filter_engine_test.go index 4a487a6c..8eee8a2f 100644 --- a/pkg/filter/filter_engine_test.go +++ b/pkg/filter/filter_engine_test.go @@ -5,7 +5,7 @@ import ( "reflect" "testing" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestFilterEngine_Run(t *testing.T) { diff --git a/pkg/filter/mock_Filter.go b/pkg/filter/mock_Filter.go index f10df601..92be8af5 100644 --- a/pkg/filter/mock_Filter.go +++ b/pkg/filter/mock_Filter.go @@ -3,7 +3,7 @@ package filter import ( - resource "github.com/snyk/driftctl/pkg/resource" + resource "github.com/snyk/driftctl/enumeration/resource" mock "github.com/stretchr/testify/mock" ) diff --git a/pkg/helpers/interface.go b/pkg/helpers/interface.go deleted file mode 100644 index e3f3088f..00000000 --- a/pkg/helpers/interface.go +++ /dev/null @@ -1,34 +0,0 @@ -package helpers - -import "strings" - -func Join(elems []interface{}, sep string) string { - firstElemt, ok := elems[0].(string) - if !ok { - panic("cannot join a slice that contains something else than strings") - } - switch len(elems) { - case 0: - return "" - case 1: - - return firstElemt - } - n := len(sep) * (len(elems) - 1) - for i := 0; i < len(elems); i++ { - n += len(elems[i].(string)) - } - - var b strings.Builder - b.Grow(n) - b.WriteString(firstElemt) - for _, s := range elems[1:] { - b.WriteString(sep) - elem, ok := s.(string) - if !ok { - panic("cannot join a slice that contains something else than strings") - } - b.WriteString(elem) - } - return b.String() -} diff --git a/pkg/iac/supplier/IacChainSupplier.go b/pkg/iac/supplier/IacChainSupplier.go index d63d5461..d2ca27a8 100644 --- a/pkg/iac/supplier/IacChainSupplier.go +++ b/pkg/iac/supplier/IacChainSupplier.go @@ -4,13 +4,15 @@ import ( "context" "runtime" + "github.com/snyk/driftctl/enumeration/parallel" + resource2 "github.com/snyk/driftctl/pkg/resource" + + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/iac" - "github.com/snyk/driftctl/pkg/parallel" - "github.com/snyk/driftctl/pkg/resource" ) type IacChainSupplier struct { - suppliers []resource.IaCSupplier + suppliers []resource2.IaCSupplier runner *parallel.ParallelRunner } @@ -28,7 +30,7 @@ func (r *IacChainSupplier) SourceCount() uint { return count } -func (r *IacChainSupplier) AddSupplier(supplier resource.IaCSupplier) { +func (r *IacChainSupplier) AddSupplier(supplier resource2.IaCSupplier) { r.suppliers = append(r.suppliers, supplier) } diff --git a/pkg/iac/supplier/IacChainSupplier_test.go b/pkg/iac/supplier/IacChainSupplier_test.go index 9c7d7885..c2071c72 100644 --- a/pkg/iac/supplier/IacChainSupplier_test.go +++ b/pkg/iac/supplier/IacChainSupplier_test.go @@ -5,28 +5,29 @@ import ( "testing" "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + resource2 "github.com/snyk/driftctl/pkg/resource" ) func TestIacChainSupplier_Resources(t *testing.T) { tests := []struct { name string - initSuppliers func(suppliers *[]resource.IaCSupplier) + initSuppliers func(suppliers *[]resource2.IaCSupplier) want []*resource.Resource wantErr bool }{ { name: "All failed", - initSuppliers: func(suppliers *[]resource.IaCSupplier) { - sup := &resource.MockIaCSupplier{} + initSuppliers: func(suppliers *[]resource2.IaCSupplier) { + sup := &resource2.MockIaCSupplier{} sup.On("Resources").Return(nil, errors.New("1")) *suppliers = append(*suppliers, sup) - sup = &resource.MockIaCSupplier{} + sup = &resource2.MockIaCSupplier{} sup.On("Resources").Return(nil, errors.New("2")) *suppliers = append(*suppliers, sup) - sup = &resource.MockIaCSupplier{} + sup = &resource2.MockIaCSupplier{} sup.On("Resources").Return(nil, errors.New("3")) *suppliers = append(*suppliers, sup) }, @@ -35,16 +36,16 @@ func TestIacChainSupplier_Resources(t *testing.T) { }, { name: "Partial failed", - initSuppliers: func(suppliers *[]resource.IaCSupplier) { - sup := &resource.MockIaCSupplier{} + initSuppliers: func(suppliers *[]resource2.IaCSupplier) { + sup := &resource2.MockIaCSupplier{} sup.On("Resources").Return(nil, errors.New("1")) *suppliers = append(*suppliers, sup) - sup = &resource.MockIaCSupplier{} + sup = &resource2.MockIaCSupplier{} sup.On("Resources").Return(nil, errors.New("2")) *suppliers = append(*suppliers, sup) - sup = &resource.MockIaCSupplier{} + sup = &resource2.MockIaCSupplier{} sup.On("Resources").Return([]*resource.Resource{ &resource.Resource{ Id: "ID", @@ -67,7 +68,7 @@ func TestIacChainSupplier_Resources(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := NewIacChainSupplier() - suppliers := make([]resource.IaCSupplier, 0) + suppliers := make([]resource2.IaCSupplier, 0) tt.initSuppliers(&suppliers) for _, supplier := range suppliers { diff --git a/pkg/iac/supplier/supplier.go b/pkg/iac/supplier/supplier.go index 184f794d..8840afb1 100644 --- a/pkg/iac/supplier/supplier.go +++ b/pkg/iac/supplier/supplier.go @@ -3,19 +3,20 @@ package supplier import ( "fmt" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/terraform" + resource2 "github.com/snyk/driftctl/pkg/resource" + "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" "github.com/snyk/driftctl/pkg/filter" + "github.com/snyk/driftctl/pkg/iac/config" "github.com/snyk/driftctl/pkg/iac/terraform/state/backend" "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/terraform" - - "github.com/snyk/driftctl/pkg/iac/config" "github.com/snyk/driftctl/pkg/iac/terraform/state" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) var supportedSuppliers = []string{ @@ -37,7 +38,7 @@ func GetIACSupplier(configs []config.SupplierConfig, progress output.Progress, alerter *alerter.Alerter, factory resource.ResourceFactory, - filter filter.Filter) (resource.IaCSupplier, error) { + filter filter.Filter) (resource2.IaCSupplier, error) { chainSupplier := NewIacChainSupplier() for _, config := range configs { @@ -47,7 +48,7 @@ func GetIACSupplier(configs []config.SupplierConfig, deserializer := resource.NewDeserializer(factory) - var supplier resource.IaCSupplier + var supplier resource2.IaCSupplier var err error switch config.Key { case state.TerraformStateReaderSupplier: diff --git a/pkg/iac/supplier/supplier_test.go b/pkg/iac/supplier/supplier_test.go index 2b0b6072..615e77b1 100644 --- a/pkg/iac/supplier/supplier_test.go +++ b/pkg/iac/supplier/supplier_test.go @@ -5,12 +5,13 @@ import ( "reflect" "testing" - "github.com/snyk/driftctl/pkg/alerter" + "github.com/snyk/driftctl/enumeration/alerter" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" + "github.com/snyk/driftctl/pkg/filter" "github.com/snyk/driftctl/pkg/iac/config" "github.com/snyk/driftctl/pkg/iac/terraform/state/backend" "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/terraform" "github.com/snyk/driftctl/test/resource" ) @@ -90,12 +91,12 @@ func TestGetIACSupplier(t *testing.T) { progress.On("Start").Return().Times(1) repo := resource.InitFakeSchemaRepository("aws", "3.19.0") - factory := terraform.NewTerraformResourceFactory(repo) + factory := terraform2.NewTerraformResourceFactory(repo) alerter := alerter.NewAlerter() testFilter := &filter.MockFilter{} - _, err := GetIACSupplier(tt.args.config, terraform.NewProviderLibrary(), tt.args.options, progress, alerter, factory, testFilter) + _, err := GetIACSupplier(tt.args.config, terraform2.NewProviderLibrary(), tt.args.options, progress, alerter, factory, testFilter) if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { t.Errorf("GetIACSupplier() error = %v, wantErr %v", err, tt.wantErr) diff --git a/pkg/iac/terraform/state/terraform_state_reader.go b/pkg/iac/terraform/state/terraform_state_reader.go index d94a2db6..16d081ad 100644 --- a/pkg/iac/terraform/state/terraform_state_reader.go +++ b/pkg/iac/terraform/state/terraform_state_reader.go @@ -4,12 +4,14 @@ import ( "fmt" "strings" + "github.com/snyk/driftctl/enumeration/alerter" + "github.com/snyk/driftctl/enumeration/terraform" + "github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/states" "github.com/hashicorp/terraform/states/statefile" "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" "github.com/snyk/driftctl/pkg/filter" "github.com/snyk/driftctl/pkg/iac" "github.com/snyk/driftctl/pkg/output" @@ -17,11 +19,10 @@ import ( ctyconvert "github.com/zclconf/go-cty/cty/convert" ctyjson "github.com/zclconf/go-cty/cty/json" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/iac/config" "github.com/snyk/driftctl/pkg/iac/terraform/state/backend" "github.com/snyk/driftctl/pkg/iac/terraform/state/enumerator" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" ) const TerraformStateReaderSupplier = "tfstate" diff --git a/pkg/iac/terraform/state/terraform_state_reader_test.go b/pkg/iac/terraform/state/terraform_state_reader_test.go index 830775e0..dd15c00e 100644 --- a/pkg/iac/terraform/state/terraform_state_reader_test.go +++ b/pkg/iac/terraform/state/terraform_state_reader_test.go @@ -7,24 +7,29 @@ import ( "strings" "testing" + "github.com/snyk/driftctl/enumeration/remote/aws" + "github.com/snyk/driftctl/enumeration/remote/azurerm" + "github.com/snyk/driftctl/enumeration/remote/github" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/terraform" + aws2 "github.com/snyk/driftctl/pkg/resource/aws" + azurerm2 "github.com/snyk/driftctl/pkg/resource/azurerm" + github2 "github.com/snyk/driftctl/pkg/resource/github" + google2 "github.com/snyk/driftctl/pkg/resource/google" + "github.com/pkg/errors" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" + resourceazure "github.com/snyk/driftctl/enumeration/resource/azurerm" + resourcegithub "github.com/snyk/driftctl/enumeration/resource/github" + resourcegoogle "github.com/snyk/driftctl/enumeration/resource/google" "github.com/snyk/driftctl/pkg/filter" "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/google" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - resourcegithub "github.com/snyk/driftctl/pkg/resource/github" - resourcegoogle "github.com/snyk/driftctl/pkg/resource/google" testresource "github.com/snyk/driftctl/test/resource" terraform2 "github.com/snyk/driftctl/test/terraform" "github.com/stretchr/testify/assert" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/iac/config" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/github" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" "github.com/snyk/driftctl/test/goldenfile" "github.com/snyk/driftctl/test/mocks" @@ -236,6 +241,7 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) { repo := testresource.InitFakeSchemaRepository(terraform.AWS, tt.providerVersion) resourceaws.InitResourcesMetadata(repo) + aws2.InitResourcesMetadata(repo) factory := terraform.NewTerraformResourceFactory(repo) @@ -268,7 +274,8 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) { t.Errorf("Resources() error = %v, wantErr %v", err, tt.wantErr) return } - changelog, err := diff.Diff(convert(got), want) + gotc := convert(got) + changelog, err := diff.Diff(gotc, want) if err != nil { panic(err) } @@ -321,6 +328,7 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) { repo := testresource.InitFakeSchemaRepository(terraform.GITHUB, "4.4.0") resourcegithub.InitResourcesMetadata(repo) + github2.InitResourcesMetadata(repo) factory := terraform.NewTerraformResourceFactory(repo) r := &TerraformStateReader{ @@ -426,6 +434,7 @@ func TestTerraformStateReader_Google_Resources(t *testing.T) { repo := testresource.InitFakeSchemaRepository(terraform.GOOGLE, providerVersion) resourcegoogle.InitResourcesMetadata(repo) + google2.InitResourcesMetadata(repo) factory := terraform.NewTerraformResourceFactory(repo) r := &TerraformStateReader{ @@ -520,6 +529,7 @@ func TestTerraformStateReader_Azure_Resources(t *testing.T) { repo := testresource.InitFakeSchemaRepository(terraform.AZURE, providerVersion) resourceazure.InitResourcesMetadata(repo) + azurerm2.InitResourcesMetadata(repo) factory := terraform.NewTerraformResourceFactory(repo) r := &TerraformStateReader{ diff --git a/pkg/iac/terraform/state/testdata/acc/multiple_states_local/route53/.terraform.lock.hcl b/pkg/iac/terraform/state/testdata/acc/multiple_states_local/route53/.terraform.lock.hcl old mode 100755 new mode 100644 index 1045bc9c..5230555a --- a/pkg/iac/terraform/state/testdata/acc/multiple_states_local/route53/.terraform.lock.hcl +++ b/pkg/iac/terraform/state/testdata/acc/multiple_states_local/route53/.terraform.lock.hcl @@ -3,9 +3,10 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.19.0" - constraints = "~> 3.19.0" + constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", @@ -23,6 +24,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/iac/terraform/state/testdata/acc/multiple_states_local/s3/.terraform.lock.hcl b/pkg/iac/terraform/state/testdata/acc/multiple_states_local/s3/.terraform.lock.hcl old mode 100755 new mode 100644 index 1045bc9c..5230555a --- a/pkg/iac/terraform/state/testdata/acc/multiple_states_local/s3/.terraform.lock.hcl +++ b/pkg/iac/terraform/state/testdata/acc/multiple_states_local/s3/.terraform.lock.hcl @@ -3,9 +3,10 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.19.0" - constraints = "~> 3.19.0" + constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", @@ -23,6 +24,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/memstore/store_test.go b/pkg/memstore/store_test.go index 1bc98ddc..6a0ee73e 100644 --- a/pkg/memstore/store_test.go +++ b/pkg/memstore/store_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" "github.com/stretchr/testify/assert" ) diff --git a/pkg/middlewares/aws_alb_listener_transformer.go b/pkg/middlewares/aws_alb_listener_transformer.go index f1bcb91b..b2be0173 100644 --- a/pkg/middlewares/aws_alb_listener_transformer.go +++ b/pkg/middlewares/aws_alb_listener_transformer.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsALBListenerTransformer is a simple middleware to turn all aws_alb_listener resources into aws_lb_listener ones diff --git a/pkg/middlewares/aws_alb_listener_transformer_test.go b/pkg/middlewares/aws_alb_listener_transformer_test.go index abcb0ea4..4064a4f8 100644 --- a/pkg/middlewares/aws_alb_listener_transformer_test.go +++ b/pkg/middlewares/aws_alb_listener_transformer_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsALBListenerTransformer_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_alb_transformer.go b/pkg/middlewares/aws_alb_transformer.go index 1b643e38..ae9bd54e 100644 --- a/pkg/middlewares/aws_alb_transformer.go +++ b/pkg/middlewares/aws_alb_transformer.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsALBTransformer is a simple middleware to turn all aws_alb resources into aws_lb ones diff --git a/pkg/middlewares/aws_alb_transformer_test.go b/pkg/middlewares/aws_alb_transformer_test.go index 82cc7bb0..b5bec4aa 100644 --- a/pkg/middlewares/aws_alb_transformer_test.go +++ b/pkg/middlewares/aws_alb_transformer_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsALBTransformer_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_api_gateway_api_expander.go b/pkg/middlewares/aws_api_gateway_api_expander.go index 1073d433..d05f0f3a 100644 --- a/pkg/middlewares/aws_api_gateway_api_expander.go +++ b/pkg/middlewares/aws_api_gateway_api_expander.go @@ -10,8 +10,8 @@ import ( "github.com/getkin/kin-openapi/openapi3" "github.com/ghodss/yaml" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes the body attribute of api gateway apis v1|v2 to dedicated resources as per Terraform documentations diff --git a/pkg/middlewares/aws_api_gateway_api_expander_test.go b/pkg/middlewares/aws_api_gateway_api_expander_test.go index ec8f484a..e02ae0c7 100644 --- a/pkg/middlewares/aws_api_gateway_api_expander_test.go +++ b/pkg/middlewares/aws_api_gateway_api_expander_test.go @@ -1,15 +1,15 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "path/filepath" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsApiGatewayApiExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler.go b/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler.go index a27ed13d..a1286aaf 100644 --- a/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler.go +++ b/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsApiGatewayBasePathMappingReconciler is used to reconcile API Gateway base path mapping (v1 and v2) diff --git a/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler_test.go b/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler_test.go index 80e67fbf..be9d8e3a 100644 --- a/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler_test.go +++ b/pkg/middlewares/aws_api_gateway_base_path_mapping_reconciler_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsApiGatewayBasePathMappingReconciler_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_api_gateway_deployment_expander.go b/pkg/middlewares/aws_api_gateway_deployment_expander.go index b88c6358..eb707dca 100644 --- a/pkg/middlewares/aws_api_gateway_deployment_expander.go +++ b/pkg/middlewares/aws_api_gateway_deployment_expander.go @@ -3,8 +3,8 @@ package middlewares import ( "strings" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Create a aws_api_gateway_stage resource from a aws_api_gateway_deployment resource and ignore the latter resource diff --git a/pkg/middlewares/aws_api_gateway_deployment_expander_test.go b/pkg/middlewares/aws_api_gateway_deployment_expander_test.go index 35392119..33bd7413 100644 --- a/pkg/middlewares/aws_api_gateway_deployment_expander_test.go +++ b/pkg/middlewares/aws_api_gateway_deployment_expander_test.go @@ -1,15 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsApiGatewayDeploymentExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_api_gateway_domain_names_reconciler.go b/pkg/middlewares/aws_api_gateway_domain_names_reconciler.go index cb65cd7b..f56ef52d 100644 --- a/pkg/middlewares/aws_api_gateway_domain_names_reconciler.go +++ b/pkg/middlewares/aws_api_gateway_domain_names_reconciler.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Used to reconcile API Gateway domain names (v1 and v2) from both remote diff --git a/pkg/middlewares/aws_api_gateway_domain_names_reconciler_test.go b/pkg/middlewares/aws_api_gateway_domain_names_reconciler_test.go index 1aee8ff9..4d81325d 100644 --- a/pkg/middlewares/aws_api_gateway_domain_names_reconciler_test.go +++ b/pkg/middlewares/aws_api_gateway_domain_names_reconciler_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsApiGatewayDomainNamesReconciler_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_api_gateway_resource_expander.go b/pkg/middlewares/aws_api_gateway_resource_expander.go index 2aaf54ec..593ae85b 100644 --- a/pkg/middlewares/aws_api_gateway_resource_expander.go +++ b/pkg/middlewares/aws_api_gateway_resource_expander.go @@ -3,8 +3,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes api gateway default resource found in aws_api_gateway_rest_api.root_resource_id from state resources to dedicated resources diff --git a/pkg/middlewares/aws_api_gateway_resource_expander_test.go b/pkg/middlewares/aws_api_gateway_resource_expander_test.go index ffaac6e4..e29ac560 100644 --- a/pkg/middlewares/aws_api_gateway_resource_expander_test.go +++ b/pkg/middlewares/aws_api_gateway_resource_expander_test.go @@ -1,15 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsApiGatewayResourceExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_api_gateway_rest_api_policy_expander.go b/pkg/middlewares/aws_api_gateway_rest_api_policy_expander.go index c6ae37dd..d6f69080 100644 --- a/pkg/middlewares/aws_api_gateway_rest_api_policy_expander.go +++ b/pkg/middlewares/aws_api_gateway_rest_api_policy_expander.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes policy found in aws_api_gateway_rest_api.policy from state resources to dedicated resources diff --git a/pkg/middlewares/aws_api_gateway_rest_api_policy_expander_test.go b/pkg/middlewares/aws_api_gateway_rest_api_policy_expander_test.go index 45cefd9b..fc60c650 100644 --- a/pkg/middlewares/aws_api_gateway_rest_api_policy_expander_test.go +++ b/pkg/middlewares/aws_api_gateway_rest_api_policy_expander_test.go @@ -1,15 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsApiGatewayRestApiPolicyPolicyExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_bucket_policy_expander.go b/pkg/middlewares/aws_bucket_policy_expander.go index 4c3631ac..d7d27822 100644 --- a/pkg/middlewares/aws_bucket_policy_expander.go +++ b/pkg/middlewares/aws_bucket_policy_expander.go @@ -3,8 +3,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes policy found in aws_s3_bucket.policy from state resources to dedicated resources diff --git a/pkg/middlewares/aws_bucket_policy_expander_test.go b/pkg/middlewares/aws_bucket_policy_expander_test.go index 7f393730..941ef244 100644 --- a/pkg/middlewares/aws_bucket_policy_expander_test.go +++ b/pkg/middlewares/aws_bucket_policy_expander_test.go @@ -1,16 +1,15 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" awssdk "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awsutil" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsBucketPolicyExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_console_api_gateway_gateway_response.go b/pkg/middlewares/aws_console_api_gateway_gateway_response.go index b2310b9a..15ab216b 100644 --- a/pkg/middlewares/aws_console_api_gateway_gateway_response.go +++ b/pkg/middlewares/aws_console_api_gateway_gateway_response.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Each API Gateway rest API has by design all the gateway responses available to edit in the console diff --git a/pkg/middlewares/aws_console_api_gateway_gateway_response_test.go b/pkg/middlewares/aws_console_api_gateway_gateway_response_test.go index 28bdea28..1092e1b0 100644 --- a/pkg/middlewares/aws_console_api_gateway_gateway_response_test.go +++ b/pkg/middlewares/aws_console_api_gateway_gateway_response_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsConsoleApiGatewayGatewayResponse_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_api_gateway_account.go b/pkg/middlewares/aws_default_api_gateway_account.go index 0d1d79b1..43653d3f 100644 --- a/pkg/middlewares/aws_default_api_gateway_account.go +++ b/pkg/middlewares/aws_default_api_gateway_account.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsDefaultApiGatewayAccount is a middleware that ignores the default API Gateway account resource in the current region. diff --git a/pkg/middlewares/aws_default_api_gateway_account_test.go b/pkg/middlewares/aws_default_api_gateway_account_test.go index a391e868..68095585 100644 --- a/pkg/middlewares/aws_default_api_gateway_account_test.go +++ b/pkg/middlewares/aws_default_api_gateway_account_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultApiGatewayAccount_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_igw_route.go b/pkg/middlewares/aws_default_igw_route.go index b727dff1..df38d756 100644 --- a/pkg/middlewares/aws_default_igw_route.go +++ b/pkg/middlewares/aws_default_igw_route.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Each region has a default vpc which has an internet gateway attached and thus the route table of this diff --git a/pkg/middlewares/aws_default_igw_route_test.go b/pkg/middlewares/aws_default_igw_route_test.go index 10a004fa..8d1918f2 100644 --- a/pkg/middlewares/aws_default_igw_route_test.go +++ b/pkg/middlewares/aws_default_igw_route_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultInternetGatewayRoute_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_internet_gateway.go b/pkg/middlewares/aws_default_internet_gateway.go index a6424111..675d97e5 100644 --- a/pkg/middlewares/aws_default_internet_gateway.go +++ b/pkg/middlewares/aws_default_internet_gateway.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Each default vpc has an internet gateway attached that should not be seen as unmanaged if not managed by IaC diff --git a/pkg/middlewares/aws_default_internet_gateway_test.go b/pkg/middlewares/aws_default_internet_gateway_test.go index 4e324afa..757e5af1 100644 --- a/pkg/middlewares/aws_default_internet_gateway_test.go +++ b/pkg/middlewares/aws_default_internet_gateway_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultInternetGateway_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_network_acl.go b/pkg/middlewares/aws_default_network_acl.go index 8429caea..f1a15f15 100644 --- a/pkg/middlewares/aws_default_network_acl.go +++ b/pkg/middlewares/aws_default_network_acl.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Default network ACL should not be shown as unmanaged as they are present by default diff --git a/pkg/middlewares/aws_default_network_acl_rule.go b/pkg/middlewares/aws_default_network_acl_rule.go index 1ef85e9a..776b5373 100644 --- a/pkg/middlewares/aws_default_network_acl_rule.go +++ b/pkg/middlewares/aws_default_network_acl_rule.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Default network acl rules should not be shown as unmanaged as they are present by default diff --git a/pkg/middlewares/aws_default_network_acl_rule_test.go b/pkg/middlewares/aws_default_network_acl_rule_test.go index f27ec11b..bce59ad2 100644 --- a/pkg/middlewares/aws_default_network_acl_rule_test.go +++ b/pkg/middlewares/aws_default_network_acl_rule_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultNetworkACLRule_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_network_acl_test.go b/pkg/middlewares/aws_default_network_acl_test.go index db40f18c..d8dd82d7 100644 --- a/pkg/middlewares/aws_default_network_acl_test.go +++ b/pkg/middlewares/aws_default_network_acl_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultNetworkACL_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_route.go b/pkg/middlewares/aws_default_route.go index bcdbea62..cf603f64 100644 --- a/pkg/middlewares/aws_default_route.go +++ b/pkg/middlewares/aws_default_route.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Default routes should not be shown as unmanaged as they are present by default diff --git a/pkg/middlewares/aws_default_route_table.go b/pkg/middlewares/aws_default_route_table.go index 657bf27f..e2bb738a 100644 --- a/pkg/middlewares/aws_default_route_table.go +++ b/pkg/middlewares/aws_default_route_table.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Default route table should not be shown as unmanaged as they are present by default diff --git a/pkg/middlewares/aws_default_route_table_test.go b/pkg/middlewares/aws_default_route_table_test.go index 4cc7b729..26fdc3e8 100644 --- a/pkg/middlewares/aws_default_route_table_test.go +++ b/pkg/middlewares/aws_default_route_table_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultRouteTable_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_route_test.go b/pkg/middlewares/aws_default_route_test.go index 37b5c95a..3ca2eb0a 100644 --- a/pkg/middlewares/aws_default_route_test.go +++ b/pkg/middlewares/aws_default_route_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultRoute_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_security_group_rule.go b/pkg/middlewares/aws_default_security_group_rule.go index ebc68f68..a939cb09 100644 --- a/pkg/middlewares/aws_default_security_group_rule.go +++ b/pkg/middlewares/aws_default_security_group_rule.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Remove default security group rules of the default security group from remote resources diff --git a/pkg/middlewares/aws_default_security_group_rule_test.go b/pkg/middlewares/aws_default_security_group_rule_test.go index c09db7d0..775dbc54 100644 --- a/pkg/middlewares/aws_default_security_group_rule_test.go +++ b/pkg/middlewares/aws_default_security_group_rule_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultSecurityGroupRule_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_sqs_queue_policy.go b/pkg/middlewares/aws_default_sqs_queue_policy.go index 9e6921ff..e18a21f3 100644 --- a/pkg/middlewares/aws_default_sqs_queue_policy.go +++ b/pkg/middlewares/aws_default_sqs_queue_policy.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // SQS queues from AWS have a weird behaviour when we fetch them. diff --git a/pkg/middlewares/aws_default_sqs_queue_policy_test.go b/pkg/middlewares/aws_default_sqs_queue_policy_test.go index 7e1200e6..0c6aa513 100644 --- a/pkg/middlewares/aws_default_sqs_queue_policy_test.go +++ b/pkg/middlewares/aws_default_sqs_queue_policy_test.go @@ -7,8 +7,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaultSQSQueuePolicy_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_default_subnet.go b/pkg/middlewares/aws_default_subnet.go index 85d4f8a8..adc6091c 100644 --- a/pkg/middlewares/aws_default_subnet.go +++ b/pkg/middlewares/aws_default_subnet.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Default subnet should not be shown as unmanaged as they are present by default diff --git a/pkg/middlewares/aws_default_vpc.go b/pkg/middlewares/aws_default_vpc.go index ddfac18c..a12ffccc 100644 --- a/pkg/middlewares/aws_default_vpc.go +++ b/pkg/middlewares/aws_default_vpc.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Default VPC should not be shown as unmanaged as they are present by default diff --git a/pkg/middlewares/aws_defaults.go b/pkg/middlewares/aws_defaults.go index 18456a61..7b9b1c87 100644 --- a/pkg/middlewares/aws_defaults.go +++ b/pkg/middlewares/aws_defaults.go @@ -4,8 +4,8 @@ import ( "strings" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) const defaultIamRolePathPrefix = "/aws-service-role/" diff --git a/pkg/middlewares/aws_defaults_test.go b/pkg/middlewares/aws_defaults_test.go index 0c3e5bdf..9810a819 100644 --- a/pkg/middlewares/aws_defaults_test.go +++ b/pkg/middlewares/aws_defaults_test.go @@ -5,8 +5,8 @@ import ( "github.com/stretchr/testify/assert" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsDefaults_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_ebs_encryption_by_default_reconciler.go b/pkg/middlewares/aws_ebs_encryption_by_default_reconciler.go index ef684048..648427bd 100644 --- a/pkg/middlewares/aws_ebs_encryption_by_default_reconciler.go +++ b/pkg/middlewares/aws_ebs_encryption_by_default_reconciler.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsEbsEncryptionByDefaultReconciler is a middleware that either creates an 'aws_ebs_encryption_by_default' resource diff --git a/pkg/middlewares/aws_ebs_encryption_by_default_reconciler_test.go b/pkg/middlewares/aws_ebs_encryption_by_default_reconciler_test.go index 5bbbd71d..7a60d3c2 100644 --- a/pkg/middlewares/aws_ebs_encryption_by_default_reconciler_test.go +++ b/pkg/middlewares/aws_ebs_encryption_by_default_reconciler_test.go @@ -1,15 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/terraform" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsEbsEncryptionByDefaultReconciler_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_eip_association_expander.go b/pkg/middlewares/aws_eip_association_expander.go index 33d663a5..75981514 100644 --- a/pkg/middlewares/aws_eip_association_expander.go +++ b/pkg/middlewares/aws_eip_association_expander.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) /** diff --git a/pkg/middlewares/aws_iam_policy_attachement_transformer.go b/pkg/middlewares/aws_iam_policy_attachement_transformer.go index 76ad5d42..84074f62 100644 --- a/pkg/middlewares/aws_iam_policy_attachement_transformer.go +++ b/pkg/middlewares/aws_iam_policy_attachement_transformer.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) /** diff --git a/pkg/middlewares/aws_iam_policy_attachement_transformer_test.go b/pkg/middlewares/aws_iam_policy_attachement_transformer_test.go index 1690dea6..ae9e3bcd 100644 --- a/pkg/middlewares/aws_iam_policy_attachement_transformer_test.go +++ b/pkg/middlewares/aws_iam_policy_attachement_transformer_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" testresource "github.com/snyk/driftctl/test/resource" ) diff --git a/pkg/middlewares/aws_instance_block_device.go b/pkg/middlewares/aws_instance_block_device.go index b086d80d..926faf5d 100644 --- a/pkg/middlewares/aws_instance_block_device.go +++ b/pkg/middlewares/aws_instance_block_device.go @@ -3,8 +3,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Remove root_block_device from aws_instance resources and create dedicated aws_ebs_volume resources diff --git a/pkg/middlewares/aws_instance_block_device_test.go b/pkg/middlewares/aws_instance_block_device_test.go index 9e374ac4..2429f90a 100644 --- a/pkg/middlewares/aws_instance_block_device_test.go +++ b/pkg/middlewares/aws_instance_block_device_test.go @@ -1,13 +1,13 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" "github.com/stretchr/testify/mock" ) diff --git a/pkg/middlewares/aws_instance_eip.go b/pkg/middlewares/aws_instance_eip.go index 07b6dc0d..e3a5c95b 100644 --- a/pkg/middlewares/aws_instance_eip.go +++ b/pkg/middlewares/aws_instance_eip.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) type AwsInstanceEIP struct{} diff --git a/pkg/middlewares/aws_instance_eip_test.go b/pkg/middlewares/aws_instance_eip_test.go index a149f9c8..c671dc0d 100644 --- a/pkg/middlewares/aws_instance_eip_test.go +++ b/pkg/middlewares/aws_instance_eip_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsInstanceEIP_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_nat_gateway_eip_assoc.go b/pkg/middlewares/aws_nat_gateway_eip_assoc.go index 8914b88a..5dc96ede 100644 --- a/pkg/middlewares/aws_nat_gateway_eip_assoc.go +++ b/pkg/middlewares/aws_nat_gateway_eip_assoc.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) type AwsNatGatewayEipAssoc struct{} diff --git a/pkg/middlewares/aws_nat_gateway_eip_assoc_test.go b/pkg/middlewares/aws_nat_gateway_eip_assoc_test.go index 995ef711..abbb3f64 100644 --- a/pkg/middlewares/aws_nat_gateway_eip_assoc_test.go +++ b/pkg/middlewares/aws_nat_gateway_eip_assoc_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsNatGatewayEipAssoc_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_network_acl_expander.go b/pkg/middlewares/aws_network_acl_expander.go index 8bdd48ac..3f09d6ea 100644 --- a/pkg/middlewares/aws_network_acl_expander.go +++ b/pkg/middlewares/aws_network_acl_expander.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // This middelware goal is to explode aws_network_acl ingress and egress block into a set of aws_network_acl_rule diff --git a/pkg/middlewares/aws_network_acl_expander_test.go b/pkg/middlewares/aws_network_acl_expander_test.go index f80c7f40..9b9bf794 100644 --- a/pkg/middlewares/aws_network_acl_expander_test.go +++ b/pkg/middlewares/aws_network_acl_expander_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestAwsNetworkACLExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_rds_cluster_instance_expander.go b/pkg/middlewares/aws_rds_cluster_instance_expander.go index 36d358ba..50d391c0 100644 --- a/pkg/middlewares/aws_rds_cluster_instance_expander.go +++ b/pkg/middlewares/aws_rds_cluster_instance_expander.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsRDSClusterInstanceExpander search for cluster instances from state to import corresponding remote db instances. diff --git a/pkg/middlewares/aws_rds_cluster_instance_expander_test.go b/pkg/middlewares/aws_rds_cluster_instance_expander_test.go index 8e4adcd3..ce62dba8 100644 --- a/pkg/middlewares/aws_rds_cluster_instance_expander_test.go +++ b/pkg/middlewares/aws_rds_cluster_instance_expander_test.go @@ -1,11 +1,11 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/stretchr/testify/assert" ) diff --git a/pkg/middlewares/aws_role_managed_policy_expander.go b/pkg/middlewares/aws_role_managed_policy_expander.go index ded2f5f1..30d301eb 100644 --- a/pkg/middlewares/aws_role_managed_policy_expander.go +++ b/pkg/middlewares/aws_role_managed_policy_expander.go @@ -4,8 +4,8 @@ import ( "fmt" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // The role of this middleware is to expand policy contained in `managed_policy_arns` to dedicated `aws_iam_policy_attachment` diff --git a/pkg/middlewares/aws_route_table_expander.go b/pkg/middlewares/aws_route_table_expander.go index 8c0ddf04..e4dcbd63 100644 --- a/pkg/middlewares/aws_route_table_expander.go +++ b/pkg/middlewares/aws_route_table_expander.go @@ -2,10 +2,10 @@ package middlewares import ( "github.com/sirupsen/logrus" + "github.com/snyk/driftctl/enumeration/alerter" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes routes found in aws_default_route_table.route and aws_route_table.route to dedicated resources diff --git a/pkg/middlewares/aws_route_table_expander_test.go b/pkg/middlewares/aws_route_table_expander_test.go index 4d27adb7..cebc6825 100644 --- a/pkg/middlewares/aws_route_table_expander_test.go +++ b/pkg/middlewares/aws_route_table_expander_test.go @@ -1,6 +1,7 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" @@ -8,10 +9,9 @@ import ( "github.com/r3labs/diff/v2" "github.com/stretchr/testify/mock" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" ) func TestAwsRouteTableExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier.go b/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier.go index 92143a6d..b040c4e3 100644 --- a/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier.go +++ b/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier.go @@ -3,8 +3,8 @@ package middlewares import ( awssdk "github.com/aws/aws-sdk-go/aws" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // AwsS3BucketPublicAccessBlockReconciler middleware ignores every s3 bucket public block that is set to the default values (every option set to false) diff --git a/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier_test.go b/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier_test.go index c129e086..56c00943 100644 --- a/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier_test.go +++ b/pkg/middlewares/aws_s3_bucket_public_access_block_reconcilier_test.go @@ -3,8 +3,8 @@ package middlewares import ( "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/stretchr/testify/assert" ) diff --git a/pkg/middlewares/aws_sns_topic_policy_expander.go b/pkg/middlewares/aws_sns_topic_policy_expander.go index bbe4e4e5..224c8852 100644 --- a/pkg/middlewares/aws_sns_topic_policy_expander.go +++ b/pkg/middlewares/aws_sns_topic_policy_expander.go @@ -4,8 +4,8 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes policy found in aws_sns_topic from state resources to aws_sns_topic_policy resources diff --git a/pkg/middlewares/aws_sns_topic_policy_expander_test.go b/pkg/middlewares/aws_sns_topic_policy_expander_test.go index 55e8bada..f443a189 100644 --- a/pkg/middlewares/aws_sns_topic_policy_expander_test.go +++ b/pkg/middlewares/aws_sns_topic_policy_expander_test.go @@ -1,19 +1,19 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/stretchr/testify/mock" - awsresource "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + awsresource "github.com/snyk/driftctl/enumeration/resource/aws" testresource "github.com/snyk/driftctl/test/resource" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestAwsSNSTopicPolicyExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/aws_sqs_queue_policy_expander.go b/pkg/middlewares/aws_sqs_queue_policy_expander.go index 3e532aaa..0d4c414e 100644 --- a/pkg/middlewares/aws_sqs_queue_policy_expander.go +++ b/pkg/middlewares/aws_sqs_queue_policy_expander.go @@ -3,8 +3,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Explodes policy found in aws_sqs_queue.policy from state resources to dedicated resources diff --git a/pkg/middlewares/aws_sqs_queue_policy_expander_test.go b/pkg/middlewares/aws_sqs_queue_policy_expander_test.go index 60738c28..9db97997 100644 --- a/pkg/middlewares/aws_sqs_queue_policy_expander_test.go +++ b/pkg/middlewares/aws_sqs_queue_policy_expander_test.go @@ -1,15 +1,15 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/stretchr/testify/mock" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" testresource "github.com/snyk/driftctl/test/resource" "github.com/r3labs/diff/v2" diff --git a/pkg/middlewares/azurerm_route_expander.go b/pkg/middlewares/azurerm_route_expander.go index 11f59b56..c402e497 100644 --- a/pkg/middlewares/azurerm_route_expander.go +++ b/pkg/middlewares/azurerm_route_expander.go @@ -3,8 +3,8 @@ package middlewares import ( "strings" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) // Explodes routes found in azurerm_route_table.route from state resources to dedicated resources diff --git a/pkg/middlewares/azurerm_route_expander_test.go b/pkg/middlewares/azurerm_route_expander_test.go index ba8dc7e8..7a24b7a8 100644 --- a/pkg/middlewares/azurerm_route_expander_test.go +++ b/pkg/middlewares/azurerm_route_expander_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) func TestAzurermRouteExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/azurerm_subnet_expander.go b/pkg/middlewares/azurerm_subnet_expander.go index 4e504900..4a0a82b4 100644 --- a/pkg/middlewares/azurerm_subnet_expander.go +++ b/pkg/middlewares/azurerm_subnet_expander.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) // Explodes subnet found in azurerm_virtual_network.subnet from state resources to dedicated resources diff --git a/pkg/middlewares/azurerm_subnet_expander_test.go b/pkg/middlewares/azurerm_subnet_expander_test.go index 82b244b3..95af5fe5 100644 --- a/pkg/middlewares/azurerm_subnet_expander_test.go +++ b/pkg/middlewares/azurerm_subnet_expander_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) func TestAzurermSubnetExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/chain_middleware.go b/pkg/middlewares/chain_middleware.go index c4ab0109..62fa5c6d 100644 --- a/pkg/middlewares/chain_middleware.go +++ b/pkg/middlewares/chain_middleware.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) type Chain []Middleware diff --git a/pkg/middlewares/chain_middleware_test.go b/pkg/middlewares/chain_middleware_test.go index e16a48b4..abe23404 100644 --- a/pkg/middlewares/chain_middleware_test.go +++ b/pkg/middlewares/chain_middleware_test.go @@ -4,7 +4,7 @@ import ( "errors" "testing" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) var callCounters map[string]int diff --git a/pkg/middlewares/default_vpc_test.go b/pkg/middlewares/default_vpc_test.go index 7f86f6ec..acc1dc30 100644 --- a/pkg/middlewares/default_vpc_test.go +++ b/pkg/middlewares/default_vpc_test.go @@ -3,9 +3,9 @@ package middlewares import ( "testing" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource/aws" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestAwsDefaultVPCShouldBeIgnored(t *testing.T) { diff --git a/pkg/middlewares/google_compute_instance_group_manager_reconciler.go b/pkg/middlewares/google_compute_instance_group_manager_reconciler.go index 24dade78..471845de 100644 --- a/pkg/middlewares/google_compute_instance_group_manager_reconciler.go +++ b/pkg/middlewares/google_compute_instance_group_manager_reconciler.go @@ -1,8 +1,8 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) type GoogleComputeInstanceGroupManagerReconciler struct{} diff --git a/pkg/middlewares/google_compute_instance_group_manager_reconciler_test.go b/pkg/middlewares/google_compute_instance_group_manager_reconciler_test.go index ac0e28b3..16d8aa1c 100644 --- a/pkg/middlewares/google_compute_instance_group_manager_reconciler_test.go +++ b/pkg/middlewares/google_compute_instance_group_manager_reconciler_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) func TestGoogleComputeInstanceGroupManagerExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/google_default_iam_member.go b/pkg/middlewares/google_default_iam_member.go index 93a18e59..a50c4c05 100644 --- a/pkg/middlewares/google_default_iam_member.go +++ b/pkg/middlewares/google_default_iam_member.go @@ -5,8 +5,8 @@ import ( "strings" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) // Some service accounts are created by default when activating APIs, this middleware will filter them unless they are managed. diff --git a/pkg/middlewares/google_default_iam_member_test.go b/pkg/middlewares/google_default_iam_member_test.go index 38b90643..130477db 100644 --- a/pkg/middlewares/google_default_iam_member_test.go +++ b/pkg/middlewares/google_default_iam_member_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) func TestGoogleDefaultIAMMember_Execute(t *testing.T) { diff --git a/pkg/middlewares/google_iam_binding_tranformer_test.go b/pkg/middlewares/google_iam_binding_tranformer_test.go index 5c43eec9..57df0cb7 100644 --- a/pkg/middlewares/google_iam_binding_tranformer_test.go +++ b/pkg/middlewares/google_iam_binding_tranformer_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) func TestGoogleProjectIAMBindingTransformer_Execute(t *testing.T) { diff --git a/pkg/middlewares/google_iam_binding_transformer.go b/pkg/middlewares/google_iam_binding_transformer.go index 820f33ea..789654f7 100644 --- a/pkg/middlewares/google_iam_binding_transformer.go +++ b/pkg/middlewares/google_iam_binding_transformer.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) // GoogleIAMBindingTransformer Transforms Bucket IAM binding in bucket iam member to ease comparison. diff --git a/pkg/middlewares/google_iam_policy_tranformer_test.go b/pkg/middlewares/google_iam_policy_tranformer_test.go index d4c3de95..7d96c6a6 100644 --- a/pkg/middlewares/google_iam_policy_tranformer_test.go +++ b/pkg/middlewares/google_iam_policy_tranformer_test.go @@ -1,14 +1,14 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) func TestGoogleProjectIAMPolicyTransformer_Execute(t *testing.T) { diff --git a/pkg/middlewares/google_iam_policy_transformer.go b/pkg/middlewares/google_iam_policy_transformer.go index 2b9b8a46..70f3aa06 100644 --- a/pkg/middlewares/google_iam_policy_transformer.go +++ b/pkg/middlewares/google_iam_policy_transformer.go @@ -5,8 +5,8 @@ import ( "fmt" "strings" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) // GoogleStorageBucketIAMPolicyTransformer Transforms Bucket IAM policy in bucket iam binding to ease comparison. diff --git a/pkg/middlewares/google_legacy_bucket_iam_member.go b/pkg/middlewares/google_legacy_bucket_iam_member.go index 1328ad47..184162ab 100644 --- a/pkg/middlewares/google_legacy_bucket_iam_member.go +++ b/pkg/middlewares/google_legacy_bucket_iam_member.go @@ -4,8 +4,8 @@ import ( "strings" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) // Creating buckets add legacy role bindings, this middleware will filter them unless they are managed. diff --git a/pkg/middlewares/google_legacy_bucket_iam_member_test.go b/pkg/middlewares/google_legacy_bucket_iam_member_test.go index 63bb92d2..4e1ee7f7 100644 --- a/pkg/middlewares/google_legacy_bucket_iam_member_test.go +++ b/pkg/middlewares/google_legacy_bucket_iam_member_test.go @@ -6,8 +6,8 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" ) func TestGoogleLegacyBucketIAMMember_Execute(t *testing.T) { diff --git a/pkg/middlewares/iam_policy_attachment_expander.go b/pkg/middlewares/iam_policy_attachment_expander.go index 787e1cfc..fd87bfc4 100644 --- a/pkg/middlewares/iam_policy_attachment_expander.go +++ b/pkg/middlewares/iam_policy_attachment_expander.go @@ -3,8 +3,8 @@ package middlewares import ( "fmt" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" ) // Split Policy attachment when there is multiple user and groups and generate a repeatable id diff --git a/pkg/middlewares/iam_policy_attachment_expander_test.go b/pkg/middlewares/iam_policy_attachment_expander_test.go index 964a33f6..29cdac28 100644 --- a/pkg/middlewares/iam_policy_attachment_expander_test.go +++ b/pkg/middlewares/iam_policy_attachment_expander_test.go @@ -1,16 +1,15 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "strings" "testing" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" + "github.com/snyk/driftctl/enumeration/resource/aws" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestIamPolicyAttachmentExpander_Execute(t *testing.T) { diff --git a/pkg/middlewares/middlewares.go b/pkg/middlewares/middlewares.go index 09d20327..6789f326 100644 --- a/pkg/middlewares/middlewares.go +++ b/pkg/middlewares/middlewares.go @@ -1,6 +1,6 @@ package middlewares -import "github.com/snyk/driftctl/pkg/resource" +import "github.com/snyk/driftctl/enumeration/resource" type Middleware interface { Execute(remoteResources, resourcesFromState *[]*resource.Resource) error diff --git a/pkg/middlewares/route53_records.go b/pkg/middlewares/route53_records.go index 7df12e35..4959a089 100644 --- a/pkg/middlewares/route53_records.go +++ b/pkg/middlewares/route53_records.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Remote NS and SAO records from remote state if not managed by IAC diff --git a/pkg/middlewares/route53_records_id_reconcilier.go b/pkg/middlewares/route53_records_id_reconcilier.go index 06db9f25..29b23721 100644 --- a/pkg/middlewares/route53_records_id_reconcilier.go +++ b/pkg/middlewares/route53_records_id_reconcilier.go @@ -4,8 +4,8 @@ import ( "strings" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Since AWS returns the FQDN as the name of the remote record, we must change the Id of the diff --git a/pkg/middlewares/route53_records_id_reconcilier_test.go b/pkg/middlewares/route53_records_id_reconcilier_test.go index 74a9bd8a..c30b75c0 100644 --- a/pkg/middlewares/route53_records_id_reconcilier_test.go +++ b/pkg/middlewares/route53_records_id_reconcilier_test.go @@ -3,8 +3,8 @@ package middlewares import ( "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/stretchr/testify/assert" ) diff --git a/pkg/middlewares/route53_records_test.go b/pkg/middlewares/route53_records_test.go index 7e9bc977..a74fe2ef 100644 --- a/pkg/middlewares/route53_records_test.go +++ b/pkg/middlewares/route53_records_test.go @@ -3,9 +3,9 @@ package middlewares import ( "testing" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource/aws" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestDefaultRoute53RecordShouldBeIgnored(t *testing.T) { diff --git a/pkg/middlewares/s3_bucket_acl.go b/pkg/middlewares/s3_bucket_acl.go index 4764978f..11ea3ecb 100644 --- a/pkg/middlewares/s3_bucket_acl.go +++ b/pkg/middlewares/s3_bucket_acl.go @@ -3,8 +3,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Remove grant field on remote resources when acl field != private in state diff --git a/pkg/middlewares/s3_bucket_acl_test.go b/pkg/middlewares/s3_bucket_acl_test.go index e353afb7..9a13a8f6 100644 --- a/pkg/middlewares/s3_bucket_acl_test.go +++ b/pkg/middlewares/s3_bucket_acl_test.go @@ -5,9 +5,9 @@ import ( "github.com/stretchr/testify/assert" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource/aws" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestS3BucketAcl_Execute(t *testing.T) { diff --git a/pkg/middlewares/tags_all_manager.go b/pkg/middlewares/tags_all_manager.go index 1e293f7b..a84b76ce 100644 --- a/pkg/middlewares/tags_all_manager.go +++ b/pkg/middlewares/tags_all_manager.go @@ -1,7 +1,7 @@ package middlewares import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) // Manage tags_all attribute on each compatible resources diff --git a/pkg/middlewares/tags_all_manager_test.go b/pkg/middlewares/tags_all_manager_test.go index 9cc05b5d..a89a9a48 100644 --- a/pkg/middlewares/tags_all_manager_test.go +++ b/pkg/middlewares/tags_all_manager_test.go @@ -6,7 +6,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) func TestTagsAllManager_Execute(t *testing.T) { diff --git a/pkg/middlewares/vpc_security_group_default.go b/pkg/middlewares/vpc_security_group_default.go index bfc76a0a..68ab7c9c 100644 --- a/pkg/middlewares/vpc_security_group_default.go +++ b/pkg/middlewares/vpc_security_group_default.go @@ -2,8 +2,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) // Remove default security group from remote resources diff --git a/pkg/middlewares/vpc_security_group_default_test.go b/pkg/middlewares/vpc_security_group_default_test.go index 38d6f2d1..3700727a 100644 --- a/pkg/middlewares/vpc_security_group_default_test.go +++ b/pkg/middlewares/vpc_security_group_default_test.go @@ -3,8 +3,8 @@ package middlewares import ( "testing" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestDefaultVPCSecurityGroupShouldBeIgnored(t *testing.T) { diff --git a/pkg/middlewares/vpc_security_group_rules.go b/pkg/middlewares/vpc_security_group_rules.go index 75973e95..278a28e4 100644 --- a/pkg/middlewares/vpc_security_group_rules.go +++ b/pkg/middlewares/vpc_security_group_rules.go @@ -3,8 +3,8 @@ package middlewares import ( "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/enumeration/resource" + resourceaws "github.com/snyk/driftctl/enumeration/resource/aws" ) // Split security group rule if it needs to given its attributes diff --git a/pkg/middlewares/vpc_security_group_rules_test.go b/pkg/middlewares/vpc_security_group_rules_test.go index 2f70c735..cee0351d 100644 --- a/pkg/middlewares/vpc_security_group_rules_test.go +++ b/pkg/middlewares/vpc_security_group_rules_test.go @@ -1,13 +1,13 @@ package middlewares import ( + "github.com/snyk/driftctl/enumeration/terraform" "testing" "github.com/stretchr/testify/mock" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) func TestVPCSecurityGroupRuleSanitizer(t *testing.T) { diff --git a/pkg/remote/alerts/alerts.go b/pkg/remote/alerts/alerts.go deleted file mode 100644 index 4bcc0b70..00000000 --- a/pkg/remote/alerts/alerts.go +++ /dev/null @@ -1,96 +0,0 @@ -package alerts - -import ( - "fmt" - - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" -) - -type ScanningPhase int - -const ( - EnumerationPhase ScanningPhase = iota - DetailsFetchingPhase -) - -type RemoteAccessDeniedAlert struct { - message string - provider string - scanningPhase ScanningPhase -} - -func NewRemoteAccessDeniedAlert(provider string, scanErr *remoteerror.ResourceScanningError, scanningPhase ScanningPhase) *RemoteAccessDeniedAlert { - var message string - switch scanningPhase { - case EnumerationPhase: - message = fmt.Sprintf( - "Ignoring %s from drift calculation: Listing %s is forbidden: %s", - scanErr.Resource(), - scanErr.ListedTypeError(), - scanErr.RootCause().Error(), - ) - case DetailsFetchingPhase: - message = fmt.Sprintf( - "Ignoring %s from drift calculation: Reading details of %s is forbidden: %s", - scanErr.Resource(), - scanErr.ListedTypeError(), - scanErr.RootCause().Error(), - ) - default: - message = fmt.Sprintf( - "Ignoring %s from drift calculation: %s", - scanErr.Resource(), - scanErr.RootCause().Error(), - ) - } - return &RemoteAccessDeniedAlert{message, provider, scanningPhase} -} - -func (e *RemoteAccessDeniedAlert) Message() string { - return e.message -} - -func (e *RemoteAccessDeniedAlert) ShouldIgnoreResource() bool { - return true -} - -func (e *RemoteAccessDeniedAlert) GetProviderMessage() string { - var message string - if e.scanningPhase == DetailsFetchingPhase { - message = "It seems that we got access denied exceptions while reading details of resources.\n" - } - if e.scanningPhase == EnumerationPhase { - message = "It seems that we got access denied exceptions while listing resources.\n" - } - - switch e.provider { - case common.RemoteGithubTerraform: - message += "Please be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy" - case common.RemoteAWSTerraform: - message += "The latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy" - case common.RemoteGoogleTerraform: - message += "Please ensure that you have configured the required roles, please check our documentation at https://docs.driftctl.com/google/policy" - default: - return "" - } - return message -} - -func sendRemoteAccessDeniedAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError, p ScanningPhase) { - logrus.WithFields(logrus.Fields{ - "resource": listError.Resource(), - "listed_type": listError.ListedTypeError(), - }).Debugf("Got an access denied error: %+v", listError.Error()) - alerter.SendAlert(listError.Resource(), NewRemoteAccessDeniedAlert(provider, listError, p)) -} - -func SendEnumerationAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError) { - sendRemoteAccessDeniedAlert(provider, alerter, listError, EnumerationPhase) -} - -func SendDetailsFetchingAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError) { - sendRemoteAccessDeniedAlert(provider, alerter, listError, DetailsFetchingPhase) -} diff --git a/pkg/remote/aws/api_gateway_account_enumerator.go b/pkg/remote/aws/api_gateway_account_enumerator.go deleted file mode 100644 index dec86260..00000000 --- a/pkg/remote/aws/api_gateway_account_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayAccountEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayAccountEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayAccountEnumerator { - return &ApiGatewayAccountEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayAccountEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayAccountResourceType -} - -func (e *ApiGatewayAccountEnumerator) Enumerate() ([]*resource.Resource, error) { - account, err := e.repository.GetAccount() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, 1) - - if account != nil { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - "api-gateway-account", - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_api_key_enumerator.go b/pkg/remote/aws/api_gateway_api_key_enumerator.go deleted file mode 100644 index 1f6b1101..00000000 --- a/pkg/remote/aws/api_gateway_api_key_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayApiKeyEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayApiKeyEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayApiKeyEnumerator { - return &ApiGatewayApiKeyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayApiKeyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayApiKeyResourceType -} - -func (e *ApiGatewayApiKeyEnumerator) Enumerate() ([]*resource.Resource, error) { - keys, err := e.repository.ListAllApiKeys() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(keys)) - - for _, key := range keys { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *key.Id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_authorizer_enumerator.go b/pkg/remote/aws/api_gateway_authorizer_enumerator.go deleted file mode 100644 index a652e6a3..00000000 --- a/pkg/remote/aws/api_gateway_authorizer_enumerator.go +++ /dev/null @@ -1,56 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayAuthorizerEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayAuthorizerEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayAuthorizerEnumerator { - return &ApiGatewayAuthorizerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayAuthorizerEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayAuthorizerResourceType -} - -func (e *ApiGatewayAuthorizerEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - authorizers, err := e.repository.ListAllRestApiAuthorizers(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, authorizer := range authorizers { - au := authorizer - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *au.Id, - map[string]interface{}{}, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_base_path_mapping_enumerator.go b/pkg/remote/aws/api_gateway_base_path_mapping_enumerator.go deleted file mode 100644 index 868996f9..00000000 --- a/pkg/remote/aws/api_gateway_base_path_mapping_enumerator.go +++ /dev/null @@ -1,64 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayBasePathMappingEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayBasePathMappingEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayBasePathMappingEnumerator { - return &ApiGatewayBasePathMappingEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayBasePathMappingEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayBasePathMappingResourceType -} - -func (e *ApiGatewayBasePathMappingEnumerator) Enumerate() ([]*resource.Resource, error) { - domainNames, err := e.repository.ListAllDomainNames() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayDomainNameResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, domainName := range domainNames { - d := domainName - mappings, err := e.repository.ListAllDomainNameBasePathMappings(*d.DomainName) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, mapping := range mappings { - m := mapping - - basePath := "" - if m.BasePath != nil && *m.BasePath != "(none)" { - basePath = *m.BasePath - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{*d.DomainName, basePath}, "/"), - map[string]interface{}{}, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_domain_name_enumerator.go b/pkg/remote/aws/api_gateway_domain_name_enumerator.go deleted file mode 100644 index b5984cdc..00000000 --- a/pkg/remote/aws/api_gateway_domain_name_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayDomainNameEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayDomainNameEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayDomainNameEnumerator { - return &ApiGatewayDomainNameEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayDomainNameEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayDomainNameResourceType -} - -func (e *ApiGatewayDomainNameEnumerator) Enumerate() ([]*resource.Resource, error) { - domainNames, err := e.repository.ListAllDomainNames() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(domainNames)) - - for _, domainName := range domainNames { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *domainName.DomainName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_gateway_response_enumerator.go b/pkg/remote/aws/api_gateway_gateway_response_enumerator.go deleted file mode 100644 index 2dd406fb..00000000 --- a/pkg/remote/aws/api_gateway_gateway_response_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayGatewayResponseEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayGatewayResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayGatewayResponseEnumerator { - return &ApiGatewayGatewayResponseEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayGatewayResponseEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayGatewayResponseResourceType -} - -func (e *ApiGatewayGatewayResponseEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - gtwResponses, err := e.repository.ListAllRestApiGatewayResponses(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, gtwResponse := range gtwResponses { - g := gtwResponse - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{"aggr", *a.Id, *g.ResponseType}, "-"), - map[string]interface{}{}, - ), - ) - } - - } - return results, err -} diff --git a/pkg/remote/aws/api_gateway_integration_enumerator.go b/pkg/remote/aws/api_gateway_integration_enumerator.go deleted file mode 100644 index eef94939..00000000 --- a/pkg/remote/aws/api_gateway_integration_enumerator.go +++ /dev/null @@ -1,59 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayIntegrationEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayIntegrationEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayIntegrationEnumerator { - return &ApiGatewayIntegrationEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayIntegrationEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayIntegrationResourceType -} - -func (e *ApiGatewayIntegrationEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - resources, err := e.repository.ListAllRestApiResources(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) - } - - for _, resource := range resources { - r := resource - for httpMethod := range r.ResourceMethods { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{"agi", *a.Id, *r.Id, httpMethod}, "-"), - map[string]interface{}{}, - ), - ) - } - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_integration_response_enumerator.go b/pkg/remote/aws/api_gateway_integration_response_enumerator.go deleted file mode 100644 index 2d138231..00000000 --- a/pkg/remote/aws/api_gateway_integration_response_enumerator.go +++ /dev/null @@ -1,63 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayIntegrationResponseEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayIntegrationResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayIntegrationResponseEnumerator { - return &ApiGatewayIntegrationResponseEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayIntegrationResponseEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayIntegrationResponseResourceType -} - -func (e *ApiGatewayIntegrationResponseEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - resources, err := e.repository.ListAllRestApiResources(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) - } - - for _, resource := range resources { - r := resource - for httpMethod, method := range r.ResourceMethods { - if method.MethodIntegration != nil { - for statusCode := range method.MethodIntegration.IntegrationResponses { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{"agir", *a.Id, *r.Id, httpMethod, statusCode}, "-"), - map[string]interface{}{}, - ), - ) - } - } - } - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_method_enumerator.go b/pkg/remote/aws/api_gateway_method_enumerator.go deleted file mode 100644 index f8c3443b..00000000 --- a/pkg/remote/aws/api_gateway_method_enumerator.go +++ /dev/null @@ -1,59 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayMethodEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayMethodEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodEnumerator { - return &ApiGatewayMethodEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayMethodEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayMethodResourceType -} - -func (e *ApiGatewayMethodEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - resources, err := e.repository.ListAllRestApiResources(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) - } - - for _, resource := range resources { - r := resource - for method := range r.ResourceMethods { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{"agm", *a.Id, *r.Id, method}, "-"), - map[string]interface{}{}, - ), - ) - } - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_method_response_enumerator.go b/pkg/remote/aws/api_gateway_method_response_enumerator.go deleted file mode 100644 index b56e7011..00000000 --- a/pkg/remote/aws/api_gateway_method_response_enumerator.go +++ /dev/null @@ -1,61 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayMethodResponseEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayMethodResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodResponseEnumerator { - return &ApiGatewayMethodResponseEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayMethodResponseEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayMethodResponseResourceType -} - -func (e *ApiGatewayMethodResponseEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - resources, err := e.repository.ListAllRestApiResources(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType) - } - - for _, resource := range resources { - r := resource - for httpMethod, method := range r.ResourceMethods { - for statusCode := range method.MethodResponses { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{"agmr", *a.Id, *r.Id, httpMethod, statusCode}, "-"), - map[string]interface{}{}, - ), - ) - } - } - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_method_settings_enumerator.go b/pkg/remote/aws/api_gateway_method_settings_enumerator.go deleted file mode 100644 index 219a5976..00000000 --- a/pkg/remote/aws/api_gateway_method_settings_enumerator.go +++ /dev/null @@ -1,59 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayMethodSettingsEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayMethodSettingsEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodSettingsEnumerator { - return &ApiGatewayMethodSettingsEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayMethodSettingsEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayMethodSettingsResourceType -} - -func (e *ApiGatewayMethodSettingsEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - stages, err := e.repository.ListAllRestApiStages(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayStageResourceType) - } - - for _, stage := range stages { - s := stage - for methodPath := range s.MethodSettings { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{*a.Id, *s.StageName, methodPath}, "-"), - map[string]interface{}{}, - ), - ) - } - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_model_enumerator.go b/pkg/remote/aws/api_gateway_model_enumerator.go deleted file mode 100644 index a959cd87..00000000 --- a/pkg/remote/aws/api_gateway_model_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayModelEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayModelEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayModelEnumerator { - return &ApiGatewayModelEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayModelEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayModelResourceType -} - -func (e *ApiGatewayModelEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - models, err := e.repository.ListAllRestApiModels(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, model := range models { - m := model - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *m.Id, - map[string]interface{}{}, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_request_validator_enumerator.go b/pkg/remote/aws/api_gateway_request_validator_enumerator.go deleted file mode 100644 index 05cdf506..00000000 --- a/pkg/remote/aws/api_gateway_request_validator_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayRequestValidatorEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayRequestValidatorEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRequestValidatorEnumerator { - return &ApiGatewayRequestValidatorEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayRequestValidatorEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayRequestValidatorResourceType -} - -func (e *ApiGatewayRequestValidatorEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - requestValidators, err := e.repository.ListAllRestApiRequestValidators(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, requestValidator := range requestValidators { - r := requestValidator - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *r.Id, - map[string]interface{}{}, - ), - ) - } - - } - return results, err -} diff --git a/pkg/remote/aws/api_gateway_resource_enumerator.go b/pkg/remote/aws/api_gateway_resource_enumerator.go deleted file mode 100644 index 8f47afb4..00000000 --- a/pkg/remote/aws/api_gateway_resource_enumerator.go +++ /dev/null @@ -1,58 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayResourceEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayResourceEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayResourceEnumerator { - return &ApiGatewayResourceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayResourceEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayResourceResourceType -} - -func (e *ApiGatewayResourceEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - resources, err := e.repository.ListAllRestApiResources(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, resource := range resources { - r := resource - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *r.Id, - map[string]interface{}{ - "rest_api_id": *a.Id, - "path": *r.Path, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_rest_api_enumerator.go b/pkg/remote/aws/api_gateway_rest_api_enumerator.go deleted file mode 100644 index 68b24fda..00000000 --- a/pkg/remote/aws/api_gateway_rest_api_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayRestApiEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayRestApiEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRestApiEnumerator { - return &ApiGatewayRestApiEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayRestApiEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayRestApiResourceType -} - -func (e *ApiGatewayRestApiEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(apis)) - - for _, api := range apis { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *api.Id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/api_gateway_rest_api_policy_enumerator.go b/pkg/remote/aws/api_gateway_rest_api_policy_enumerator.go deleted file mode 100644 index 4520c153..00000000 --- a/pkg/remote/aws/api_gateway_rest_api_policy_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayRestApiPolicyEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayRestApiPolicyEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRestApiPolicyEnumerator { - return &ApiGatewayRestApiPolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayRestApiPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayRestApiPolicyResourceType -} - -func (e *ApiGatewayRestApiPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - if a.Policy == nil || *a.Policy == "" { - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *a.Id, - map[string]interface{}{}, - ), - ) - } - return results, err -} diff --git a/pkg/remote/aws/api_gateway_stage_enumerator.go b/pkg/remote/aws/api_gateway_stage_enumerator.go deleted file mode 100644 index 79dc6363..00000000 --- a/pkg/remote/aws/api_gateway_stage_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayStageEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayStageEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayStageEnumerator { - return &ApiGatewayStageEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayStageEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayStageResourceType -} - -func (e *ApiGatewayStageEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllRestApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - stages, err := e.repository.ListAllRestApiStages(*a.Id) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, stage := range stages { - s := stage - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{"ags", *a.Id, *s.StageName}, "-"), - map[string]interface{}{}, - ), - ) - } - - } - return results, err -} diff --git a/pkg/remote/aws/api_gateway_vpc_link_enumerator.go b/pkg/remote/aws/api_gateway_vpc_link_enumerator.go deleted file mode 100644 index ad24665a..00000000 --- a/pkg/remote/aws/api_gateway_vpc_link_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayVpcLinkEnumerator struct { - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayVpcLinkEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayVpcLinkEnumerator { - return &ApiGatewayVpcLinkEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayVpcLinkEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayVpcLinkResourceType -} - -func (e *ApiGatewayVpcLinkEnumerator) Enumerate() ([]*resource.Resource, error) { - vpcLinks, err := e.repository.ListAllVpcLinks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(vpcLinks)) - - for _, vpcLink := range vpcLinks { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *vpcLink.Id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_api_enumerator.go b/pkg/remote/aws/apigatewayv2_api_enumerator.go deleted file mode 100644 index d4eefee9..00000000 --- a/pkg/remote/aws/apigatewayv2_api_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2ApiEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2ApiEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2ApiEnumerator { - return &ApiGatewayV2ApiEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2ApiEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2ApiResourceType -} - -func (e *ApiGatewayV2ApiEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(apis)) - - for _, api := range apis { - a := api - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *a.ApiId, - map[string]interface{}{}, - ), - ) - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_authorizer_enumerator.go b/pkg/remote/aws/apigatewayv2_authorizer_enumerator.go deleted file mode 100644 index cff25da2..00000000 --- a/pkg/remote/aws/apigatewayv2_authorizer_enumerator.go +++ /dev/null @@ -1,56 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2AuthorizerEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2AuthorizerEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2AuthorizerEnumerator { - return &ApiGatewayV2AuthorizerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2AuthorizerEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2AuthorizerResourceType -} - -func (e *ApiGatewayV2AuthorizerEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - a := api - authorizers, err := e.repository.ListAllApiAuthorizers(*a.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, authorizer := range authorizers { - au := authorizer - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *au.AuthorizerId, - map[string]interface{}{}, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_deployment_enumerator.go b/pkg/remote/aws/apigatewayv2_deployment_enumerator.go deleted file mode 100644 index 4a66f92d..00000000 --- a/pkg/remote/aws/apigatewayv2_deployment_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2DeploymentEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2DeploymentEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2DeploymentEnumerator { - return &ApiGatewayV2DeploymentEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2DeploymentEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2DeploymentResourceType -} - -func (e *ApiGatewayV2DeploymentEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - var results []*resource.Resource - for _, api := range apis { - deployments, err := e.repository.ListAllApiDeployments(api.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, deployment := range deployments { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *deployment.DeploymentId, - map[string]interface{}{}, - ), - ) - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_domain_name_enumerator.go b/pkg/remote/aws/apigatewayv2_domain_name_enumerator.go deleted file mode 100644 index 64b70bb8..00000000 --- a/pkg/remote/aws/apigatewayv2_domain_name_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2DomainNameEnumerator struct { - // AWS SDK list domain names endpoint from API Gateway v2 returns the - // same results as the v1 one, thus let's re-use the method from - // the API Gateway v1 - repository repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayV2DomainNameEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayV2DomainNameEnumerator { - return &ApiGatewayV2DomainNameEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2DomainNameEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2DomainNameResourceType -} - -func (e *ApiGatewayV2DomainNameEnumerator) Enumerate() ([]*resource.Resource, error) { - domainNames, err := e.repository.ListAllDomainNames() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(domainNames)) - - for _, domainName := range domainNames { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *domainName.DomainName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_integration_enumerator.go b/pkg/remote/aws/apigatewayv2_integration_enumerator.go deleted file mode 100644 index 7b7434ba..00000000 --- a/pkg/remote/aws/apigatewayv2_integration_enumerator.go +++ /dev/null @@ -1,63 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2IntegrationEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2IntegrationEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2IntegrationEnumerator { - return &ApiGatewayV2IntegrationEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2IntegrationEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2IntegrationResourceType -} - -func (e *ApiGatewayV2IntegrationEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, a := range apis { - api := a - integrations, err := e.repository.ListAllApiIntegrations(*api.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, integration := range integrations { - data := map[string]interface{}{ - "api_id": *api.ApiId, - "integration_type": *integration.IntegrationType, - } - - if integration.IntegrationMethod != nil { - // this is needed to discriminate in middleware. But it is nil when the type is mock... - data["integration_method"] = *integration.IntegrationMethod - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *integration.IntegrationId, - data, - ), - ) - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_integration_response_enumerator.go b/pkg/remote/aws/apigatewayv2_integration_response_enumerator.go deleted file mode 100644 index 0b2976a7..00000000 --- a/pkg/remote/aws/apigatewayv2_integration_response_enumerator.go +++ /dev/null @@ -1,63 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2IntegrationResponseEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2IntegrationResponseEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2IntegrationResponseEnumerator { - return &ApiGatewayV2IntegrationResponseEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2IntegrationResponseEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2IntegrationResponseResourceType -} - -func (e *ApiGatewayV2IntegrationResponseEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, a := range apis { - apiID := *a.ApiId - integrations, err := e.repository.ListAllApiIntegrations(apiID) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2IntegrationResourceType) - } - - for _, integration := range integrations { - integrationId := *integration.IntegrationId - responses, err := e.repository.ListAllApiIntegrationResponses(apiID, integrationId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, resp := range responses { - responseId := *resp.IntegrationResponseId - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - responseId, - map[string]interface{}{}, - ), - ) - } - - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_mapping_enumerator.go b/pkg/remote/aws/apigatewayv2_mapping_enumerator.go deleted file mode 100644 index 0bc1e876..00000000 --- a/pkg/remote/aws/apigatewayv2_mapping_enumerator.go +++ /dev/null @@ -1,61 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2MappingEnumerator struct { - repository repository.ApiGatewayV2Repository - repositoryV1 repository.ApiGatewayRepository - factory resource.ResourceFactory -} - -func NewApiGatewayV2MappingEnumerator(repo repository.ApiGatewayV2Repository, repov1 repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayV2MappingEnumerator { - return &ApiGatewayV2MappingEnumerator{ - repository: repo, - repositoryV1: repov1, - factory: factory, - } -} - -func (e *ApiGatewayV2MappingEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2MappingResourceType -} - -func (e *ApiGatewayV2MappingEnumerator) Enumerate() ([]*resource.Resource, error) { - domainNames, err := e.repositoryV1.ListAllDomainNames() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayDomainNameResourceType) - } - - var results []*resource.Resource - for _, domainName := range domainNames { - mappings, err := e.repository.ListAllApiMappings(*domainName.DomainName) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, mapping := range mappings { - attrs := make(map[string]interface{}) - - if mapping.ApiId != nil { - attrs["api_id"] = *mapping.ApiId - } - if mapping.Stage != nil { - attrs["stage"] = *mapping.Stage - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *mapping.ApiMappingId, - attrs, - ), - ) - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_model_enumerator.go b/pkg/remote/aws/apigatewayv2_model_enumerator.go deleted file mode 100644 index 64f7f0eb..00000000 --- a/pkg/remote/aws/apigatewayv2_model_enumerator.go +++ /dev/null @@ -1,52 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2ModelEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2ModelEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2ModelEnumerator { - return &ApiGatewayV2ModelEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2ModelEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2ModelResourceType -} - -func (e *ApiGatewayV2ModelEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - var results []*resource.Resource - for _, api := range apis { - models, err := e.repository.ListAllApiModels(*api.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, model := range models { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *model.ModelId, - map[string]interface{}{ - "name": *model.Name, - }, - ), - ) - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_route_enumerator.go b/pkg/remote/aws/apigatewayv2_route_enumerator.go deleted file mode 100644 index ac41a8a8..00000000 --- a/pkg/remote/aws/apigatewayv2_route_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2RouteEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2RouteEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2RouteEnumerator { - return &ApiGatewayV2RouteEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2RouteEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2RouteResourceType -} - -func (e *ApiGatewayV2RouteEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - var results []*resource.Resource - for _, api := range apis { - routes, err := e.repository.ListAllApiRoutes(api.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, route := range routes { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *route.RouteId, - map[string]interface{}{ - "api_id": *api.ApiId, - "route_key": *route.RouteKey, - }, - ), - ) - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_route_response_enumerator.go b/pkg/remote/aws/apigatewayv2_route_response_enumerator.go deleted file mode 100644 index 4e23bb09..00000000 --- a/pkg/remote/aws/apigatewayv2_route_response_enumerator.go +++ /dev/null @@ -1,59 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2RouteResponseEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2RouteResponseEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2RouteResponseEnumerator { - return &ApiGatewayV2RouteResponseEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2RouteResponseEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2RouteResponseResourceType -} - -func (e *ApiGatewayV2RouteResponseEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - var results []*resource.Resource - for _, api := range apis { - a := api - routes, err := e.repository.ListAllApiRoutes(a.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2RouteResourceType) - } - for _, route := range routes { - r := route - responses, err := e.repository.ListAllApiRouteResponses(*a.ApiId, *r.RouteId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, response := range responses { - res := response - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.RouteResponseId, - map[string]interface{}{}, - ), - ) - } - } - } - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_stage_enumerator.go b/pkg/remote/aws/apigatewayv2_stage_enumerator.go deleted file mode 100644 index 8041a036..00000000 --- a/pkg/remote/aws/apigatewayv2_stage_enumerator.go +++ /dev/null @@ -1,54 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2StageEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2StageEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2StageEnumerator { - return &ApiGatewayV2StageEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2StageEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2StageResourceType -} - -func (e *ApiGatewayV2StageEnumerator) Enumerate() ([]*resource.Resource, error) { - apis, err := e.repository.ListAllApis() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, api := range apis { - stages, err := e.repository.ListAllApiStages(*api.ApiId) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, stage := range stages { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *stage.StageName, - map[string]interface{}{}, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/aws/apigatewayv2_vpc_link_enumerator.go b/pkg/remote/aws/apigatewayv2_vpc_link_enumerator.go deleted file mode 100644 index 8f5c4bc0..00000000 --- a/pkg/remote/aws/apigatewayv2_vpc_link_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ApiGatewayV2VpcLinkEnumerator struct { - repository repository.ApiGatewayV2Repository - factory resource.ResourceFactory -} - -func NewApiGatewayV2VpcLinkEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2VpcLinkEnumerator { - return &ApiGatewayV2VpcLinkEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ApiGatewayV2VpcLinkEnumerator) SupportedType() resource.ResourceType { - return aws.AwsApiGatewayV2VpcLinkResourceType -} - -func (e *ApiGatewayV2VpcLinkEnumerator) Enumerate() ([]*resource.Resource, error) { - vpcLinks, err := e.repository.ListAllVpcLinks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(vpcLinks)) - - for _, vpcLink := range vpcLinks { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *vpcLink.VpcLinkId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/appautoscaling_policy_enumerator.go b/pkg/remote/aws/appautoscaling_policy_enumerator.go deleted file mode 100644 index e773cee4..00000000 --- a/pkg/remote/aws/appautoscaling_policy_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type AppAutoscalingPolicyEnumerator struct { - repository repository.AppAutoScalingRepository - factory resource.ResourceFactory -} - -func NewAppAutoscalingPolicyEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingPolicyEnumerator { - return &AppAutoscalingPolicyEnumerator{ - repository, - factory, - } -} - -func (e *AppAutoscalingPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsAppAutoscalingPolicyResourceType -} - -func (e *AppAutoscalingPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - results := make([]*resource.Resource, 0) - - for _, ns := range e.repository.ServiceNamespaceValues() { - policies, err := e.repository.DescribeScalingPolicies(ns) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, policy := range policies { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *policy.PolicyName, - map[string]interface{}{ - "name": *policy.PolicyName, - "resource_id": *policy.ResourceId, - "scalable_dimension": *policy.ScalableDimension, - "service_namespace": *policy.ServiceNamespace, - }, - ), - ) - } - } - - return results, nil -} diff --git a/pkg/remote/aws/appautoscaling_scheduled_action_enumerator.go b/pkg/remote/aws/appautoscaling_scheduled_action_enumerator.go deleted file mode 100644 index d77874d8..00000000 --- a/pkg/remote/aws/appautoscaling_scheduled_action_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type AppAutoscalingScheduledActionEnumerator struct { - repository repository.AppAutoScalingRepository - factory resource.ResourceFactory -} - -func NewAppAutoscalingScheduledActionEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingScheduledActionEnumerator { - return &AppAutoscalingScheduledActionEnumerator{ - repository, - factory, - } -} - -func (e *AppAutoscalingScheduledActionEnumerator) SupportedType() resource.ResourceType { - return aws.AwsAppAutoscalingScheduledActionResourceType -} - -func (e *AppAutoscalingScheduledActionEnumerator) Enumerate() ([]*resource.Resource, error) { - results := make([]*resource.Resource, 0) - - for _, ns := range e.repository.ServiceNamespaceValues() { - actions, err := e.repository.DescribeScheduledActions(ns) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, action := range actions { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.Join([]string{*action.ScheduledActionName, *action.ServiceNamespace, *action.ResourceId}, "-"), - map[string]interface{}{}, - ), - ) - } - } - - return results, nil -} diff --git a/pkg/remote/aws/appautoscaling_target_enumerator.go b/pkg/remote/aws/appautoscaling_target_enumerator.go deleted file mode 100644 index 177f3f1b..00000000 --- a/pkg/remote/aws/appautoscaling_target_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "github.com/aws/aws-sdk-go/service/applicationautoscaling" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type AppAutoscalingTargetEnumerator struct { - repository repository.AppAutoScalingRepository - factory resource.ResourceFactory -} - -func NewAppAutoscalingTargetEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingTargetEnumerator { - return &AppAutoscalingTargetEnumerator{ - repository, - factory, - } -} - -func (e *AppAutoscalingTargetEnumerator) SupportedType() resource.ResourceType { - return aws.AwsAppAutoscalingTargetResourceType -} - -func (e *AppAutoscalingTargetEnumerator) Enumerate() ([]*resource.Resource, error) { - targets := make([]*applicationautoscaling.ScalableTarget, 0) - - for _, ns := range e.repository.ServiceNamespaceValues() { - results, err := e.repository.DescribeScalableTargets(ns) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - targets = append(targets, results...) - } - - results := make([]*resource.Resource, 0, len(targets)) - - for _, target := range targets { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *target.ResourceId, - map[string]interface{}{ - "service_namespace": *target.ServiceNamespace, - "scalable_dimension": *target.ScalableDimension, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/classic_loadbalancer_enumerator.go b/pkg/remote/aws/classic_loadbalancer_enumerator.go deleted file mode 100644 index fa14f3fc..00000000 --- a/pkg/remote/aws/classic_loadbalancer_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ClassicLoadBalancerEnumerator struct { - repository repository.ELBRepository - factory resource.ResourceFactory -} - -func NewClassicLoadBalancerEnumerator(repo repository.ELBRepository, factory resource.ResourceFactory) *ClassicLoadBalancerEnumerator { - return &ClassicLoadBalancerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ClassicLoadBalancerEnumerator) SupportedType() resource.ResourceType { - return aws.AwsClassicLoadBalancerResourceType -} - -func (e *ClassicLoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) { - loadBalancers, err := e.repository.ListAllLoadBalancers() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(loadBalancers)) - - for _, lb := range loadBalancers { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *lb.LoadBalancerName, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/cloudformation_stack_enumerator.go b/pkg/remote/aws/cloudformation_stack_enumerator.go deleted file mode 100644 index 7c3d064b..00000000 --- a/pkg/remote/aws/cloudformation_stack_enumerator.go +++ /dev/null @@ -1,60 +0,0 @@ -package aws - -import ( - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type CloudformationStackEnumerator struct { - repository repository.CloudformationRepository - factory resource.ResourceFactory -} - -func NewCloudformationStackEnumerator(repo repository.CloudformationRepository, factory resource.ResourceFactory) *CloudformationStackEnumerator { - return &CloudformationStackEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *CloudformationStackEnumerator) SupportedType() resource.ResourceType { - return aws.AwsCloudformationStackResourceType -} - -func (e *CloudformationStackEnumerator) Enumerate() ([]*resource.Resource, error) { - stacks, err := e.repository.ListAllStacks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(stacks)) - - for _, stack := range stacks { - attrs := map[string]interface{}{} - if stack.Parameters != nil && len(stack.Parameters) > 0 { - attrs["parameters"] = flattenParameters(stack.Parameters) - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *stack.StackId, - attrs, - ), - ) - } - - return results, err -} - -func flattenParameters(parameters []*cloudformation.Parameter) interface{} { - params := make(map[string]interface{}, len(parameters)) - for _, p := range parameters { - params[*p.ParameterKey] = *p.ParameterValue - } - return params -} diff --git a/pkg/remote/aws/cloudfront_distribution_enumerator.go b/pkg/remote/aws/cloudfront_distribution_enumerator.go deleted file mode 100644 index 9114e2ac..00000000 --- a/pkg/remote/aws/cloudfront_distribution_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type CloudfrontDistributionEnumerator struct { - repository repository.CloudfrontRepository - factory resource.ResourceFactory -} - -func NewCloudfrontDistributionEnumerator(repo repository.CloudfrontRepository, factory resource.ResourceFactory) *CloudfrontDistributionEnumerator { - return &CloudfrontDistributionEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *CloudfrontDistributionEnumerator) SupportedType() resource.ResourceType { - return aws.AwsCloudfrontDistributionResourceType -} - -func (e *CloudfrontDistributionEnumerator) Enumerate() ([]*resource.Resource, error) { - distributions, err := e.repository.ListAllDistributions() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(distributions)) - - for _, distribution := range distributions { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *distribution.Id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/default_vpc_enumerator.go b/pkg/remote/aws/default_vpc_enumerator.go deleted file mode 100644 index 31ba4508..00000000 --- a/pkg/remote/aws/default_vpc_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/resource" -) - -type DefaultVPCEnumerator struct { - repo repository.EC2Repository - factory resource.ResourceFactory -} - -func NewDefaultVPCEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *DefaultVPCEnumerator { - return &DefaultVPCEnumerator{ - repo, - factory, - } -} - -func (e *DefaultVPCEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDefaultVpcResourceType -} - -func (e *DefaultVPCEnumerator) Enumerate() ([]*resource.Resource, error) { - _, defaultVPCs, err := e.repo.ListAllVPCs() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(defaultVPCs)) - - for _, item := range defaultVPCs { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *item.VpcId, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/dynamodb_table_enumerator.go b/pkg/remote/aws/dynamodb_table_enumerator.go deleted file mode 100644 index 4085712f..00000000 --- a/pkg/remote/aws/dynamodb_table_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type DynamoDBTableEnumerator struct { - repository repository.DynamoDBRepository - factory resource.ResourceFactory -} - -func NewDynamoDBTableEnumerator(repository repository.DynamoDBRepository, factory resource.ResourceFactory) *DynamoDBTableEnumerator { - return &DynamoDBTableEnumerator{ - repository, - factory, - } -} - -func (e *DynamoDBTableEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDynamodbTableResourceType -} - -func (e *DynamoDBTableEnumerator) Enumerate() ([]*resource.Resource, error) { - tables, err := e.repository.ListAllTables() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(tables)) - - for _, table := range tables { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *table, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/ebs_encryption_by_default_enumerator.go b/pkg/remote/aws/ebs_encryption_by_default_enumerator.go deleted file mode 100644 index 8ccf8937..00000000 --- a/pkg/remote/aws/ebs_encryption_by_default_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2EbsEncryptionByDefaultEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2EbsEncryptionByDefaultEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsEncryptionByDefaultEnumerator { - return &EC2EbsEncryptionByDefaultEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2EbsEncryptionByDefaultEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEbsEncryptionByDefaultResourceType -} - -func (e *EC2EbsEncryptionByDefaultEnumerator) Enumerate() ([]*resource.Resource, error) { - enabled, err := e.repository.IsEbsEncryptionEnabledByDefault() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - "ebs_encryption_default", - map[string]interface{}{ - "enabled": enabled, - }, - ), - ) - - return results, err -} diff --git a/pkg/remote/aws/ec2_ami_enumerator.go b/pkg/remote/aws/ec2_ami_enumerator.go deleted file mode 100644 index fe4b5baa..00000000 --- a/pkg/remote/aws/ec2_ami_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2AmiEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2AmiEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2AmiEnumerator { - return &EC2AmiEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2AmiEnumerator) SupportedType() resource.ResourceType { - return aws.AwsAmiResourceType -} - -func (e *EC2AmiEnumerator) Enumerate() ([]*resource.Resource, error) { - images, err := e.repository.ListAllImages() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(images)) - - for _, image := range images { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *image.ImageId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_default_network_acl_enumerator.go b/pkg/remote/aws/ec2_default_network_acl_enumerator.go deleted file mode 100644 index 3d268a67..00000000 --- a/pkg/remote/aws/ec2_default_network_acl_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2DefaultNetworkACLEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2DefaultNetworkACLEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultNetworkACLEnumerator { - return &EC2DefaultNetworkACLEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2DefaultNetworkACLEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDefaultNetworkACLResourceType -} - -func (e *EC2DefaultNetworkACLEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllNetworkACLs() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - // Do not handle non-default network acl since it is a dedicated resource - if !*res.IsDefault { - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.NetworkAclId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_default_route_table_enumerator.go b/pkg/remote/aws/ec2_default_route_table_enumerator.go deleted file mode 100644 index 6bdf4514..00000000 --- a/pkg/remote/aws/ec2_default_route_table_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2DefaultRouteTableEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2DefaultRouteTableEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultRouteTableEnumerator { - return &EC2DefaultRouteTableEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2DefaultRouteTableEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDefaultRouteTableResourceType -} - -func (e *EC2DefaultRouteTableEnumerator) Enumerate() ([]*resource.Resource, error) { - routeTables, err := e.repository.ListAllRouteTables() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - var results []*resource.Resource - - for _, routeTable := range routeTables { - if isMainRouteTable(routeTable) { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *routeTable.RouteTableId, - map[string]interface{}{ - "vpc_id": *routeTable.VpcId, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_default_subnet_enumerator.go b/pkg/remote/aws/ec2_default_subnet_enumerator.go deleted file mode 100644 index 80979510..00000000 --- a/pkg/remote/aws/ec2_default_subnet_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2DefaultSubnetEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2DefaultSubnetEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultSubnetEnumerator { - return &EC2DefaultSubnetEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2DefaultSubnetEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDefaultSubnetResourceType -} - -func (e *EC2DefaultSubnetEnumerator) Enumerate() ([]*resource.Resource, error) { - _, defaultSubnets, err := e.repository.ListAllSubnets() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(defaultSubnets)) - - for _, subnet := range defaultSubnets { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *subnet.SubnetId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_ebs_snapshot_enumerator.go b/pkg/remote/aws/ec2_ebs_snapshot_enumerator.go deleted file mode 100644 index 5742e908..00000000 --- a/pkg/remote/aws/ec2_ebs_snapshot_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2EbsSnapshotEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2EbsSnapshotEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsSnapshotEnumerator { - return &EC2EbsSnapshotEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2EbsSnapshotEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEbsSnapshotResourceType -} - -func (e *EC2EbsSnapshotEnumerator) Enumerate() ([]*resource.Resource, error) { - snapshots, err := e.repository.ListAllSnapshots() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(snapshots)) - - for _, snapshot := range snapshots { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *snapshot.SnapshotId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_ebs_volume_enumerator.go b/pkg/remote/aws/ec2_ebs_volume_enumerator.go deleted file mode 100644 index bfbe6033..00000000 --- a/pkg/remote/aws/ec2_ebs_volume_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2EbsVolumeEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2EbsVolumeEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsVolumeEnumerator { - return &EC2EbsVolumeEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2EbsVolumeEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEbsVolumeResourceType -} - -func (e *EC2EbsVolumeEnumerator) Enumerate() ([]*resource.Resource, error) { - volumes, err := e.repository.ListAllVolumes() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(volumes)) - - for _, volume := range volumes { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *volume.VolumeId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_eip_association_enumerator.go b/pkg/remote/aws/ec2_eip_association_enumerator.go deleted file mode 100644 index d533450c..00000000 --- a/pkg/remote/aws/ec2_eip_association_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2EipAssociationEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2EipAssociationEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EipAssociationEnumerator { - return &EC2EipAssociationEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2EipAssociationEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEipAssociationResourceType -} - -func (e *EC2EipAssociationEnumerator) Enumerate() ([]*resource.Resource, error) { - addresses, err := e.repository.ListAllAddressesAssociation() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(addresses)) - - for _, address := range addresses { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *address.AssociationId, - map[string]interface{}{ - "allocation_id": *address.AllocationId, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_eip_enumerator.go b/pkg/remote/aws/ec2_eip_enumerator.go deleted file mode 100644 index 9513c62e..00000000 --- a/pkg/remote/aws/ec2_eip_enumerator.go +++ /dev/null @@ -1,51 +0,0 @@ -package aws - -import ( - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2EipEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2EipEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EipEnumerator { - return &EC2EipEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2EipEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEipResourceType -} - -func (e *EC2EipEnumerator) Enumerate() ([]*resource.Resource, error) { - addresses, err := e.repository.ListAllAddresses() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(addresses)) - - for _, address := range addresses { - if address.AllocationId == nil { - logrus.Warn("Elastic IP does not have an allocation ID, ignoring") - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *address.AllocationId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_instance_enumerator.go b/pkg/remote/aws/ec2_instance_enumerator.go deleted file mode 100644 index 80a39c3c..00000000 --- a/pkg/remote/aws/ec2_instance_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2InstanceEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2InstanceEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2InstanceEnumerator { - return &EC2InstanceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2InstanceEnumerator) SupportedType() resource.ResourceType { - return aws.AwsInstanceResourceType -} - -func (e *EC2InstanceEnumerator) Enumerate() ([]*resource.Resource, error) { - instances, err := e.repository.ListAllInstances() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(instances)) - - for _, instance := range instances { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *instance.InstanceId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_internet_gateway_enumerator.go b/pkg/remote/aws/ec2_internet_gateway_enumerator.go deleted file mode 100644 index 7d974851..00000000 --- a/pkg/remote/aws/ec2_internet_gateway_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2InternetGatewayEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2InternetGatewayEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2InternetGatewayEnumerator { - return &EC2InternetGatewayEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2InternetGatewayEnumerator) SupportedType() resource.ResourceType { - return aws.AwsInternetGatewayResourceType -} - -func (e *EC2InternetGatewayEnumerator) Enumerate() ([]*resource.Resource, error) { - internetGateways, err := e.repository.ListAllInternetGateways() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(internetGateways)) - - for _, internetGateway := range internetGateways { - data := map[string]interface{}{} - if len(internetGateway.Attachments) > 0 && internetGateway.Attachments[0].VpcId != nil { - data["vpc_id"] = *internetGateway.Attachments[0].VpcId - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *internetGateway.InternetGatewayId, - data, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_key_pair_enumerator.go b/pkg/remote/aws/ec2_key_pair_enumerator.go deleted file mode 100644 index 52f95aad..00000000 --- a/pkg/remote/aws/ec2_key_pair_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2KeyPairEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2KeyPairEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2KeyPairEnumerator { - return &EC2KeyPairEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2KeyPairEnumerator) SupportedType() resource.ResourceType { - return aws.AwsKeyPairResourceType -} - -func (e *EC2KeyPairEnumerator) Enumerate() ([]*resource.Resource, error) { - keyPairs, err := e.repository.ListAllKeyPairs() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(keyPairs)) - - for _, keyPair := range keyPairs { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *keyPair.KeyName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_nat_gateway_enumerator.go b/pkg/remote/aws/ec2_nat_gateway_enumerator.go deleted file mode 100644 index 46f44294..00000000 --- a/pkg/remote/aws/ec2_nat_gateway_enumerator.go +++ /dev/null @@ -1,54 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2NatGatewayEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2NatGatewayEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NatGatewayEnumerator { - return &EC2NatGatewayEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2NatGatewayEnumerator) SupportedType() resource.ResourceType { - return aws.AwsNatGatewayResourceType -} - -func (e *EC2NatGatewayEnumerator) Enumerate() ([]*resource.Resource, error) { - natGateways, err := e.repository.ListAllNatGateways() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(natGateways)) - - for _, natGateway := range natGateways { - - attrs := map[string]interface{}{} - if len(natGateway.NatGatewayAddresses) > 0 { - if allocId := natGateway.NatGatewayAddresses[0].AllocationId; allocId != nil { - attrs["allocation_id"] = *allocId - } - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *natGateway.NatGatewayId, - attrs, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_network_acl_enumerator.go b/pkg/remote/aws/ec2_network_acl_enumerator.go deleted file mode 100644 index 5cae702f..00000000 --- a/pkg/remote/aws/ec2_network_acl_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2NetworkACLEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2NetworkACLEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NetworkACLEnumerator { - return &EC2NetworkACLEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2NetworkACLEnumerator) SupportedType() resource.ResourceType { - return aws.AwsNetworkACLResourceType -} - -func (e *EC2NetworkACLEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllNetworkACLs() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - // Do not handle default network acl since it is a dedicated resource - if *res.IsDefault { - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.NetworkAclId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_network_acl_rule_enumerator.go b/pkg/remote/aws/ec2_network_acl_rule_enumerator.go deleted file mode 100644 index 5f343f51..00000000 --- a/pkg/remote/aws/ec2_network_acl_rule_enumerator.go +++ /dev/null @@ -1,65 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2NetworkACLRuleEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2NetworkACLRuleEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NetworkACLRuleEnumerator { - return &EC2NetworkACLRuleEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2NetworkACLRuleEnumerator) SupportedType() resource.ResourceType { - return aws.AwsNetworkACLRuleResourceType -} - -func (e *EC2NetworkACLRuleEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllNetworkACLs() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsNetworkACLResourceType) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - for _, entry := range res.Entries { - - attrs := map[string]interface{}{ - "egress": *entry.Egress, - "network_acl_id": *res.NetworkAclId, - "rule_action": *entry.RuleAction, // Used in default middleware - "rule_number": float64(*entry.RuleNumber), // Used in default middleware - "protocol": *entry.Protocol, // Used in default middleware - } - - if entry.CidrBlock != nil { - attrs["cidr_block"] = *entry.CidrBlock - } - - if entry.Ipv6CidrBlock != nil { - attrs["ipv6_cidr_block"] = *entry.Ipv6CidrBlock - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - "", // Will be computed during normalization - attrs, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_route_enumerator.go b/pkg/remote/aws/ec2_route_enumerator.go deleted file mode 100644 index a6bae4ff..00000000 --- a/pkg/remote/aws/ec2_route_enumerator.go +++ /dev/null @@ -1,66 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2RouteEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2RouteEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteEnumerator { - return &EC2RouteEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2RouteEnumerator) SupportedType() resource.ResourceType { - return aws.AwsRouteResourceType -} - -func (e *EC2RouteEnumerator) Enumerate() ([]*resource.Resource, error) { - routeTables, err := e.repository.ListAllRouteTables() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsRouteTableResourceType) - } - - var results []*resource.Resource - - for _, routeTable := range routeTables { - for _, route := range routeTable.Routes { - routeId := aws.CalculateRouteID(routeTable.RouteTableId, route.DestinationCidrBlock, route.DestinationIpv6CidrBlock, route.DestinationPrefixListId) - data := map[string]interface{}{ - "route_table_id": *routeTable.RouteTableId, - "origin": *route.Origin, - } - if route.DestinationCidrBlock != nil && *route.DestinationCidrBlock != "" { - data["destination_cidr_block"] = *route.DestinationCidrBlock - } - if route.DestinationIpv6CidrBlock != nil && *route.DestinationIpv6CidrBlock != "" { - data["destination_ipv6_cidr_block"] = *route.DestinationIpv6CidrBlock - } - if route.DestinationPrefixListId != nil && *route.DestinationPrefixListId != "" { - data["destination_prefix_list_id"] = *route.DestinationPrefixListId - } - if route.GatewayId != nil && *route.GatewayId != "" { - data["gateway_id"] = *route.GatewayId - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - routeId, - data, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/ec2_route_table_association_enumerator.go b/pkg/remote/aws/ec2_route_table_association_enumerator.go deleted file mode 100644 index 723dd2b1..00000000 --- a/pkg/remote/aws/ec2_route_table_association_enumerator.go +++ /dev/null @@ -1,69 +0,0 @@ -package aws - -import ( - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2RouteTableAssociationEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2RouteTableAssociationEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteTableAssociationEnumerator { - return &EC2RouteTableAssociationEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2RouteTableAssociationEnumerator) SupportedType() resource.ResourceType { - return aws.AwsRouteTableAssociationResourceType -} - -func (e *EC2RouteTableAssociationEnumerator) Enumerate() ([]*resource.Resource, error) { - routeTables, err := e.repository.ListAllRouteTables() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsRouteTableResourceType) - } - - var results []*resource.Resource - - for _, routeTable := range routeTables { - for _, assoc := range routeTable.Associations { - if e.shouldBeIgnored(assoc) { - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *assoc.RouteTableAssociationId, - map[string]interface{}{ - "route_table_id": *assoc.RouteTableId, - }, - ), - ) - } - } - - return results, err -} - -func (e *EC2RouteTableAssociationEnumerator) shouldBeIgnored(assoc *ec2.RouteTableAssociation) bool { - // Ignore when nothing is associated - if assoc.GatewayId == nil && assoc.SubnetId == nil { - return true - } - - // Ignore when association is not associated - if assoc.AssociationState != nil && assoc.AssociationState.State != nil && - *assoc.AssociationState.State != "associated" { - return true - } - - return false -} diff --git a/pkg/remote/aws/ec2_route_table_enumerator.go b/pkg/remote/aws/ec2_route_table_enumerator.go deleted file mode 100644 index 41e0a677..00000000 --- a/pkg/remote/aws/ec2_route_table_enumerator.go +++ /dev/null @@ -1,58 +0,0 @@ -package aws - -import ( - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2RouteTableEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2RouteTableEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteTableEnumerator { - return &EC2RouteTableEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2RouteTableEnumerator) SupportedType() resource.ResourceType { - return aws.AwsRouteTableResourceType -} - -func (e *EC2RouteTableEnumerator) Enumerate() ([]*resource.Resource, error) { - routeTables, err := e.repository.ListAllRouteTables() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - var results []*resource.Resource - - for _, routeTable := range routeTables { - if !isMainRouteTable(routeTable) { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *routeTable.RouteTableId, - map[string]interface{}{}, - ), - ) - } - } - - return results, err -} - -func isMainRouteTable(routeTable *ec2.RouteTable) bool { - for _, assoc := range routeTable.Associations { - if assoc.Main != nil && *assoc.Main { - return true - } - } - return false -} diff --git a/pkg/remote/aws/ec2_subnet_enumerator.go b/pkg/remote/aws/ec2_subnet_enumerator.go deleted file mode 100644 index cf33875d..00000000 --- a/pkg/remote/aws/ec2_subnet_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type EC2SubnetEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewEC2SubnetEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2SubnetEnumerator { - return &EC2SubnetEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *EC2SubnetEnumerator) SupportedType() resource.ResourceType { - return aws.AwsSubnetResourceType -} - -func (e *EC2SubnetEnumerator) Enumerate() ([]*resource.Resource, error) { - subnets, _, err := e.repository.ListAllSubnets() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(subnets)) - - for _, subnet := range subnets { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *subnet.SubnetId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ecr_repository_enumerator.go b/pkg/remote/aws/ecr_repository_enumerator.go deleted file mode 100644 index 22c6cab2..00000000 --- a/pkg/remote/aws/ecr_repository_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ECRRepositoryEnumerator struct { - repository repository.ECRRepository - factory resource.ResourceFactory -} - -func NewECRRepositoryEnumerator(repo repository.ECRRepository, factory resource.ResourceFactory) *ECRRepositoryEnumerator { - return &ECRRepositoryEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ECRRepositoryEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEcrRepositoryResourceType -} - -func (e *ECRRepositoryEnumerator) Enumerate() ([]*resource.Resource, error) { - repos, err := e.repository.ListAllRepositories() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(repos)) - - for _, repo := range repos { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *repo.RepositoryName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/ecr_repository_policy_enumerator.go b/pkg/remote/aws/ecr_repository_policy_enumerator.go deleted file mode 100644 index c50617bb..00000000 --- a/pkg/remote/aws/ecr_repository_policy_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "github.com/aws/aws-sdk-go/service/ecr" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ECRRepositoryPolicyEnumerator struct { - repository repository.ECRRepository - factory resource.ResourceFactory -} - -func NewECRRepositoryPolicyEnumerator(repo repository.ECRRepository, factory resource.ResourceFactory) *ECRRepositoryPolicyEnumerator { - return &ECRRepositoryPolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ECRRepositoryPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsEcrRepositoryPolicyResourceType -} - -func (e *ECRRepositoryPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - repos, err := e.repository.ListAllRepositories() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsEcrRepositoryResourceType) - } - - results := make([]*resource.Resource, 0, len(repos)) - - for _, repo := range repos { - repoOutput, err := e.repository.GetRepositoryPolicy(repo) - if _, ok := err.(*ecr.RepositoryPolicyNotFoundException); ok { - continue - } - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *repoOutput.RepositoryName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/elasticache_cluster_enumerator.go b/pkg/remote/aws/elasticache_cluster_enumerator.go deleted file mode 100644 index 3d41ecb9..00000000 --- a/pkg/remote/aws/elasticache_cluster_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type ElastiCacheClusterEnumerator struct { - repository repository.ElastiCacheRepository - factory resource.ResourceFactory -} - -func NewElastiCacheClusterEnumerator(repo repository.ElastiCacheRepository, factory resource.ResourceFactory) *ElastiCacheClusterEnumerator { - return &ElastiCacheClusterEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *ElastiCacheClusterEnumerator) SupportedType() resource.ResourceType { - return aws.AwsElastiCacheClusterResourceType -} - -func (e *ElastiCacheClusterEnumerator) Enumerate() ([]*resource.Resource, error) { - clusters, err := e.repository.ListAllCacheClusters() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(clusters)) - - for _, cluster := range clusters { - c := cluster - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *c.CacheClusterId, - map[string]interface{}{}, - ), - ) - } - return results, err -} diff --git a/pkg/remote/aws/iam_access_key_enumerator.go b/pkg/remote/aws/iam_access_key_enumerator.go deleted file mode 100644 index 7e2b0fda..00000000 --- a/pkg/remote/aws/iam_access_key_enumerator.go +++ /dev/null @@ -1,52 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamAccessKeyEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamAccessKeyEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamAccessKeyEnumerator { - return &IamAccessKeyEnumerator{ - repository, - factory, - } -} - -func (e *IamAccessKeyEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsIamAccessKeyResourceType -} - -func (e *IamAccessKeyEnumerator) Enumerate() ([]*resource.Resource, error) { - users, err := e.repository.ListAllUsers() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamUserResourceType) - } - - keys, err := e.repository.ListAllAccessKeys(users) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - for _, key := range keys { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *key.AccessKeyId, - map[string]interface{}{ - "user": *key.UserName, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/iam_group_enumerator.go b/pkg/remote/aws/iam_group_enumerator.go deleted file mode 100644 index e24f424a..00000000 --- a/pkg/remote/aws/iam_group_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamGroupEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamGroupEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamGroupEnumerator { - return &IamGroupEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *IamGroupEnumerator) SupportedType() resource.ResourceType { - return aws.AwsIamGroupResourceType -} - -func (e *IamGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - groups, err := e.repository.ListAllGroups() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamGroupResourceType) - } - - results := make([]*resource.Resource, 0, len(groups)) - - for _, group := range groups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *group.GroupName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/iam_group_policy_attachment_enumerator.go b/pkg/remote/aws/iam_group_policy_attachment_enumerator.go deleted file mode 100644 index c16dc7d1..00000000 --- a/pkg/remote/aws/iam_group_policy_attachment_enumerator.go +++ /dev/null @@ -1,56 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamGroupPolicyAttachmentEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamGroupPolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamGroupPolicyAttachmentEnumerator { - return &IamGroupPolicyAttachmentEnumerator{ - repository, - factory, - } -} - -func (e *IamGroupPolicyAttachmentEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsIamGroupPolicyAttachmentResourceType -} - -func (e *IamGroupPolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) { - groups, err := e.repository.ListAllGroups() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamGroupResourceType) - } - - results := make([]*resource.Resource, 0) - - policyAttachments, err := e.repository.ListAllGroupPolicyAttachments(groups) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, attachedPol := range policyAttachments { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.GroupName), - map[string]interface{}{ - "group": attachedPol.GroupName, - "policy_arn": *attachedPol.PolicyArn, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/iam_group_policy_enumerator.go b/pkg/remote/aws/iam_group_policy_enumerator.go deleted file mode 100644 index e2a9d10c..00000000 --- a/pkg/remote/aws/iam_group_policy_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamGroupPolicyEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamGroupPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamGroupPolicyEnumerator { - return &IamGroupPolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *IamGroupPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsIamGroupPolicyResourceType -} - -func (e *IamGroupPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - groups, err := e.repository.ListAllGroups() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamGroupResourceType) - } - groupPolicies, err := e.repository.ListAllGroupPolicies(groups) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(groupPolicies)) - - for _, groupPolicy := range groupPolicies { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - groupPolicy, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/iam_policy_enumerator.go b/pkg/remote/aws/iam_policy_enumerator.go deleted file mode 100644 index e1f8a2f3..00000000 --- a/pkg/remote/aws/iam_policy_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package aws - -import ( - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamPolicyEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamPolicyEnumerator { - return &IamPolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *IamPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsIamPolicyResourceType -} - -func (e *IamPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - policies, err := e.repository.ListAllPolicies() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(policies)) - - for _, policy := range policies { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - awssdk.StringValue(policy.Arn), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/iam_role_enumerator.go b/pkg/remote/aws/iam_role_enumerator.go deleted file mode 100644 index 65b2db8e..00000000 --- a/pkg/remote/aws/iam_role_enumerator.go +++ /dev/null @@ -1,65 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -var iamRoleExclusionList = map[string]struct{}{ - // Enabled by default for aws to enable support, not removable - "AWSServiceRoleForSupport": {}, - // Enabled and not removable for every org account - "AWSServiceRoleForOrganizations": {}, - // Not manageable by IaC and set by default - "AWSServiceRoleForTrustedAdvisor": {}, -} - -type IamRoleEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamRoleEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRoleEnumerator { - return &IamRoleEnumerator{ - repository, - factory, - } -} - -func (e *IamRoleEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsIamRoleResourceType -} - -func awsIamRoleShouldBeIgnored(roleName string) bool { - _, ok := iamRoleExclusionList[roleName] - return ok -} - -func (e *IamRoleEnumerator) Enumerate() ([]*resource.Resource, error) { - roles, err := e.repository.ListAllRoles() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - for _, role := range roles { - if role.RoleName != nil && awsIamRoleShouldBeIgnored(*role.RoleName) { - continue - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *role.RoleName, - map[string]interface{}{ - "path": *role.Path, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/iam_role_policy_attachment_enumerator.go b/pkg/remote/aws/iam_role_policy_attachment_enumerator.go deleted file mode 100644 index 0685ffdb..00000000 --- a/pkg/remote/aws/iam_role_policy_attachment_enumerator.go +++ /dev/null @@ -1,69 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/service/iam" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamRolePolicyAttachmentEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamRolePolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRolePolicyAttachmentEnumerator { - return &IamRolePolicyAttachmentEnumerator{ - repository, - factory, - } -} - -func (e *IamRolePolicyAttachmentEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsIamRolePolicyAttachmentResourceType -} - -func (e *IamRolePolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) { - roles, err := e.repository.ListAllRoles() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamRoleResourceType) - } - - results := make([]*resource.Resource, 0) - rolesNotIgnored := make([]*iam.Role, 0) - - for _, role := range roles { - if role.RoleName != nil && awsIamRoleShouldBeIgnored(*role.RoleName) { - continue - } - rolesNotIgnored = append(rolesNotIgnored, role) - } - - if len(rolesNotIgnored) == 0 { - return results, nil - } - - policyAttachments, err := e.repository.ListAllRolePolicyAttachments(rolesNotIgnored) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, attachedPol := range policyAttachments { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.RoleName), - map[string]interface{}{ - "role": attachedPol.RoleName, - "policy_arn": *attachedPol.PolicyArn, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/iam_role_policy_enumerator.go b/pkg/remote/aws/iam_role_policy_enumerator.go deleted file mode 100644 index fb1de508..00000000 --- a/pkg/remote/aws/iam_role_policy_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamRolePolicyEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamRolePolicyEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRolePolicyEnumerator { - return &IamRolePolicyEnumerator{ - repository, - factory, - } -} - -func (e *IamRolePolicyEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsIamRolePolicyResourceType -} - -func (e *IamRolePolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - roles, err := e.repository.ListAllRoles() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamRoleResourceType) - } - - policies, err := e.repository.ListAllRolePolicies(roles) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(policies)) - for _, policy := range policies { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - fmt.Sprintf("%s:%s", policy.RoleName, policy.Policy), - map[string]interface{}{ - "role": policy.RoleName, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/iam_user_enumerator.go b/pkg/remote/aws/iam_user_enumerator.go deleted file mode 100644 index 9b2d1120..00000000 --- a/pkg/remote/aws/iam_user_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package aws - -import ( - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamUserEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamUserEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamUserEnumerator { - return &IamUserEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *IamUserEnumerator) SupportedType() resource.ResourceType { - return aws.AwsIamUserResourceType -} - -func (e *IamUserEnumerator) Enumerate() ([]*resource.Resource, error) { - users, err := e.repository.ListAllUsers() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(users)) - - for _, user := range users { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - awssdk.StringValue(user.UserName), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/iam_user_policy_attachment_enumerator.go b/pkg/remote/aws/iam_user_policy_attachment_enumerator.go deleted file mode 100644 index 3c0a1586..00000000 --- a/pkg/remote/aws/iam_user_policy_attachment_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamUserPolicyAttachmentEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamUserPolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamUserPolicyAttachmentEnumerator { - return &IamUserPolicyAttachmentEnumerator{ - repository, - factory, - } -} - -func (e *IamUserPolicyAttachmentEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsIamUserPolicyAttachmentResourceType -} - -func (e *IamUserPolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) { - users, err := e.repository.ListAllUsers() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamUserResourceType) - } - - results := make([]*resource.Resource, 0) - policyAttachments, err := e.repository.ListAllUserPolicyAttachments(users) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, attachedPol := range policyAttachments { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.UserName), - map[string]interface{}{ - "user": attachedPol.UserName, - "policy_arn": *attachedPol.PolicyArn, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/iam_user_policy_enumerator.go b/pkg/remote/aws/iam_user_policy_enumerator.go deleted file mode 100644 index fdc938c0..00000000 --- a/pkg/remote/aws/iam_user_policy_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type IamUserPolicyEnumerator struct { - repository repository.IAMRepository - factory resource.ResourceFactory -} - -func NewIamUserPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamUserPolicyEnumerator { - return &IamUserPolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *IamUserPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsIamUserPolicyResourceType -} - -func (e *IamUserPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - users, err := e.repository.ListAllUsers() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamUserResourceType) - } - userPolicies, err := e.repository.ListAllUserPolicies(users) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(userPolicies)) - - for _, userPolicy := range userPolicies { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - userPolicy, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/init.go b/pkg/remote/aws/init.go deleted file mode 100644 index d601dbb5..00000000 --- a/pkg/remote/aws/init.go +++ /dev/null @@ -1,258 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/aws/client" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" -) - -/** - * Initialize remote (configure credentials, launch tf providers and start gRPC clients) - * Required to use Scanner - */ - -func Init(version string, alerter *alerter.Alerter, - providerLibrary *terraform.ProviderLibrary, - remoteLibrary *common.RemoteLibrary, - progress output.Progress, - resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory, - configDir string) error { - - provider, err := NewAWSTerraformProvider(version, progress, configDir) - if err != nil { - return err - } - err = provider.CheckCredentialsExist() - if err != nil { - return err - } - err = provider.Init() - if err != nil { - return err - } - - repositoryCache := cache.New(100) - - s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache) - ec2repository := repository.NewEC2Repository(provider.session, repositoryCache) - elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache) - route53repository := repository.NewRoute53Repository(provider.session, repositoryCache) - lambdaRepository := repository.NewLambdaRepository(provider.session, repositoryCache) - rdsRepository := repository.NewRDSRepository(provider.session, repositoryCache) - sqsRepository := repository.NewSQSRepository(provider.session, repositoryCache) - snsRepository := repository.NewSNSRepository(provider.session, repositoryCache) - cloudfrontRepository := repository.NewCloudfrontRepository(provider.session, repositoryCache) - dynamoDBRepository := repository.NewDynamoDBRepository(provider.session, repositoryCache) - ecrRepository := repository.NewECRRepository(provider.session, repositoryCache) - kmsRepository := repository.NewKMSRepository(provider.session, repositoryCache) - iamRepository := repository.NewIAMRepository(provider.session, repositoryCache) - cloudformationRepository := repository.NewCloudformationRepository(provider.session, repositoryCache) - apigatewayRepository := repository.NewApiGatewayRepository(provider.session, repositoryCache) - appAutoScalingRepository := repository.NewAppAutoScalingRepository(provider.session, repositoryCache) - apigatewayv2Repository := repository.NewApiGatewayV2Repository(provider.session, repositoryCache) - autoscalingRepository := repository.NewAutoScalingRepository(provider.session, repositoryCache) - elbRepository := repository.NewELBRepository(provider.session, repositoryCache) - elasticacheRepository := repository.NewElastiCacheRepository(provider.session, repositoryCache) - - deserializer := resource.NewDeserializer(factory) - providerLibrary.AddProvider(terraform.AWS, provider) - - remoteLibrary.AddEnumerator(NewS3BucketEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewS3BucketInventoryEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketInventoryResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketInventoryResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewS3BucketNotificationEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketNotificationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketNotificationResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewS3BucketMetricsEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketMetricResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketMetricResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewS3BucketPolicyEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketPolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketPolicyResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter)) - - remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2EbsSnapshotEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsEbsSnapshotResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsSnapshotResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2EipEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsEipResourceType, common.NewGenericDetailsFetcher(aws.AwsEipResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2AmiEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsAmiResourceType, common.NewGenericDetailsFetcher(aws.AwsAmiResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2KeyPairEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsKeyPairResourceType, common.NewGenericDetailsFetcher(aws.AwsKeyPairResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2EipAssociationEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsEipAssociationResourceType, common.NewGenericDetailsFetcher(aws.AwsEipAssociationResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2InstanceEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsInstanceResourceType, common.NewGenericDetailsFetcher(aws.AwsInstanceResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2InternetGatewayEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsInternetGatewayResourceType, common.NewGenericDetailsFetcher(aws.AwsInternetGatewayResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewVPCEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsVpcResourceType, common.NewGenericDetailsFetcher(aws.AwsVpcResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewDefaultVPCEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDefaultVpcResourceType, common.NewGenericDetailsFetcher(aws.AwsDefaultVpcResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2RouteTableEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRouteTableResourceType, common.NewGenericDetailsFetcher(aws.AwsRouteTableResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2DefaultRouteTableEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDefaultRouteTableResourceType, common.NewGenericDetailsFetcher(aws.AwsDefaultRouteTableResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2RouteTableAssociationEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRouteTableAssociationResourceType, common.NewGenericDetailsFetcher(aws.AwsRouteTableAssociationResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2SubnetEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSubnetResourceType, common.NewGenericDetailsFetcher(aws.AwsSubnetResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2DefaultSubnetEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDefaultSubnetResourceType, common.NewGenericDetailsFetcher(aws.AwsDefaultSubnetResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewVPCSecurityGroupEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSecurityGroupResourceType, common.NewGenericDetailsFetcher(aws.AwsSecurityGroupResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewVPCDefaultSecurityGroupEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDefaultSecurityGroupResourceType, common.NewGenericDetailsFetcher(aws.AwsDefaultSecurityGroupResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2NatGatewayEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsNatGatewayResourceType, common.NewGenericDetailsFetcher(aws.AwsNatGatewayResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2NetworkACLEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsNetworkACLResourceType, common.NewGenericDetailsFetcher(aws.AwsNetworkACLResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2NetworkACLRuleEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsNetworkACLRuleResourceType, common.NewGenericDetailsFetcher(aws.AwsNetworkACLRuleResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2DefaultNetworkACLEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDefaultNetworkACLResourceType, common.NewGenericDetailsFetcher(aws.AwsDefaultNetworkACLResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2RouteEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRouteResourceType, common.NewGenericDetailsFetcher(aws.AwsRouteResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewVPCSecurityGroupRuleEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSecurityGroupRuleResourceType, common.NewGenericDetailsFetcher(aws.AwsSecurityGroupRuleResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewLaunchTemplateEnumerator(ec2repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsLaunchTemplateResourceType, common.NewGenericDetailsFetcher(aws.AwsLaunchTemplateResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewEC2EbsEncryptionByDefaultEnumerator(ec2repository, factory)) - - remoteLibrary.AddEnumerator(NewKMSKeyEnumerator(kmsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsKmsKeyResourceType, common.NewGenericDetailsFetcher(aws.AwsKmsKeyResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewKMSAliasEnumerator(kmsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsKmsAliasResourceType, common.NewGenericDetailsFetcher(aws.AwsKmsAliasResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewRoute53HealthCheckEnumerator(route53repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRoute53HealthCheckResourceType, common.NewGenericDetailsFetcher(aws.AwsRoute53HealthCheckResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewRoute53ZoneEnumerator(route53repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRoute53ZoneResourceType, common.NewGenericDetailsFetcher(aws.AwsRoute53ZoneResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewRoute53RecordEnumerator(route53repository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRoute53RecordResourceType, common.NewGenericDetailsFetcher(aws.AwsRoute53RecordResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewCloudfrontDistributionEnumerator(cloudfrontRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsCloudfrontDistributionResourceType, common.NewGenericDetailsFetcher(aws.AwsCloudfrontDistributionResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewRDSDBInstanceEnumerator(rdsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDbInstanceResourceType, common.NewGenericDetailsFetcher(aws.AwsDbInstanceResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewRDSDBSubnetGroupEnumerator(rdsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDbSubnetGroupResourceType, common.NewGenericDetailsFetcher(aws.AwsDbSubnetGroupResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewSQSQueueEnumerator(sqsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSqsQueueResourceType, NewSQSQueueDetailsFetcher(provider, deserializer)) - remoteLibrary.AddEnumerator(NewSQSQueuePolicyEnumerator(sqsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSqsQueuePolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsSqsQueuePolicyResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewSNSTopicEnumerator(snsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicResourceType, common.NewGenericDetailsFetcher(aws.AwsSnsTopicResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewSNSTopicPolicyEnumerator(snsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicPolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsSnsTopicPolicyResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewSNSTopicSubscriptionEnumerator(snsRepository, factory, alerter)) - remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicSubscriptionResourceType, common.NewGenericDetailsFetcher(aws.AwsSnsTopicSubscriptionResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewDynamoDBTableEnumerator(dynamoDBRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsDynamodbTableResourceType, common.NewGenericDetailsFetcher(aws.AwsDynamodbTableResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewIamPolicyEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamPolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsIamPolicyResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewLambdaFunctionEnumerator(lambdaRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsLambdaFunctionResourceType, common.NewGenericDetailsFetcher(aws.AwsLambdaFunctionResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewLambdaEventSourceMappingEnumerator(lambdaRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsLambdaEventSourceMappingResourceType, common.NewGenericDetailsFetcher(aws.AwsLambdaEventSourceMappingResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewIamUserEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamUserResourceType, common.NewGenericDetailsFetcher(aws.AwsIamUserResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamUserPolicyEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamUserPolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsIamUserPolicyResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamRoleEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamRoleResourceType, common.NewGenericDetailsFetcher(aws.AwsIamRoleResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamAccessKeyEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamAccessKeyResourceType, common.NewGenericDetailsFetcher(aws.AwsIamAccessKeyResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamRolePolicyAttachmentEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamRolePolicyAttachmentResourceType, common.NewGenericDetailsFetcher(aws.AwsIamRolePolicyAttachmentResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamRolePolicyEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamRolePolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsIamRolePolicyResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamUserPolicyAttachmentEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsIamUserPolicyAttachmentResourceType, common.NewGenericDetailsFetcher(aws.AwsIamUserPolicyAttachmentResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewIamGroupPolicyEnumerator(iamRepository, factory)) - remoteLibrary.AddEnumerator(NewIamGroupEnumerator(iamRepository, factory)) - remoteLibrary.AddEnumerator(NewIamGroupPolicyAttachmentEnumerator(iamRepository, factory)) - - remoteLibrary.AddEnumerator(NewECRRepositoryEnumerator(ecrRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsEcrRepositoryResourceType, common.NewGenericDetailsFetcher(aws.AwsEcrRepositoryResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewECRRepositoryPolicyEnumerator(ecrRepository, factory)) - - remoteLibrary.AddEnumerator(NewRDSClusterEnumerator(rdsRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsRDSClusterResourceType, common.NewGenericDetailsFetcher(aws.AwsRDSClusterResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewCloudformationStackEnumerator(cloudformationRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsCloudformationStackResourceType, common.NewGenericDetailsFetcher(aws.AwsCloudformationStackResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewApiGatewayRestApiEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayAccountEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayApiKeyEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayAuthorizerEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayStageEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayResourceEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayDomainNameEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayVpcLinkEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayRequestValidatorEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayRestApiPolicyEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayBasePathMappingEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayMethodEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayModelEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayMethodResponseEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayGatewayResponseEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayMethodSettingsEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayIntegrationEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayIntegrationResponseEnumerator(apigatewayRepository, factory)) - - remoteLibrary.AddEnumerator(NewApiGatewayV2ApiEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2RouteEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2DeploymentEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2VpcLinkEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2AuthorizerEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2IntegrationEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2ModelEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2StageEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2RouteResponseEnumerator(apigatewayv2Repository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2MappingEnumerator(apigatewayv2Repository, apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2DomainNameEnumerator(apigatewayRepository, factory)) - remoteLibrary.AddEnumerator(NewApiGatewayV2IntegrationResponseEnumerator(apigatewayv2Repository, factory)) - - remoteLibrary.AddEnumerator(NewAppAutoscalingTargetEnumerator(appAutoScalingRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsAppAutoscalingTargetResourceType, common.NewGenericDetailsFetcher(aws.AwsAppAutoscalingTargetResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewAppAutoscalingPolicyEnumerator(appAutoScalingRepository, factory)) - remoteLibrary.AddDetailsFetcher(aws.AwsAppAutoscalingPolicyResourceType, common.NewGenericDetailsFetcher(aws.AwsAppAutoscalingPolicyResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewAppAutoscalingScheduledActionEnumerator(appAutoScalingRepository, factory)) - - remoteLibrary.AddEnumerator(NewLaunchConfigurationEnumerator(autoscalingRepository, factory)) - - remoteLibrary.AddEnumerator(NewLoadBalancerEnumerator(elbv2Repository, factory)) - remoteLibrary.AddEnumerator(NewLoadBalancerListenerEnumerator(elbv2Repository, factory)) - - remoteLibrary.AddEnumerator(NewClassicLoadBalancerEnumerator(elbRepository, factory)) - - remoteLibrary.AddEnumerator(NewElastiCacheClusterEnumerator(elasticacheRepository, factory)) - - err = resourceSchemaRepository.Init(terraform.AWS, provider.Version(), provider.Schema()) - if err != nil { - return err - } - aws.InitResourcesMetadata(resourceSchemaRepository) - - return nil -} diff --git a/pkg/remote/aws/kms_alias_enumerator.go b/pkg/remote/aws/kms_alias_enumerator.go deleted file mode 100644 index 43f550a3..00000000 --- a/pkg/remote/aws/kms_alias_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type KMSAliasEnumerator struct { - repository repository.KMSRepository - factory resource.ResourceFactory -} - -func NewKMSAliasEnumerator(repo repository.KMSRepository, factory resource.ResourceFactory) *KMSAliasEnumerator { - return &KMSAliasEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *KMSAliasEnumerator) SupportedType() resource.ResourceType { - return aws.AwsKmsAliasResourceType -} - -func (e *KMSAliasEnumerator) Enumerate() ([]*resource.Resource, error) { - aliases, err := e.repository.ListAllAliases() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(aliases)) - - for _, alias := range aliases { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *alias.AliasName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/kms_key_enumerator.go b/pkg/remote/aws/kms_key_enumerator.go deleted file mode 100644 index 376844f1..00000000 --- a/pkg/remote/aws/kms_key_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type KMSKeyEnumerator struct { - repository repository.KMSRepository - factory resource.ResourceFactory -} - -func NewKMSKeyEnumerator(repo repository.KMSRepository, factory resource.ResourceFactory) *KMSKeyEnumerator { - return &KMSKeyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *KMSKeyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsKmsKeyResourceType -} - -func (e *KMSKeyEnumerator) Enumerate() ([]*resource.Resource, error) { - keys, err := e.repository.ListAllKeys() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(keys)) - - for _, key := range keys { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *key.KeyId, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/lambda_event_source_mapping_enumerator.go b/pkg/remote/aws/lambda_event_source_mapping_enumerator.go deleted file mode 100644 index f8b0781c..00000000 --- a/pkg/remote/aws/lambda_event_source_mapping_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type LambdaEventSourceMappingEnumerator struct { - repository repository.LambdaRepository - factory resource.ResourceFactory -} - -func NewLambdaEventSourceMappingEnumerator(repo repository.LambdaRepository, factory resource.ResourceFactory) *LambdaEventSourceMappingEnumerator { - return &LambdaEventSourceMappingEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *LambdaEventSourceMappingEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsLambdaEventSourceMappingResourceType -} - -func (e *LambdaEventSourceMappingEnumerator) Enumerate() ([]*resource.Resource, error) { - eventSourceMappings, err := e.repository.ListAllLambdaEventSourceMappings() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(eventSourceMappings)) - - for _, eventSourceMapping := range eventSourceMappings { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *eventSourceMapping.UUID, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/lambda_function_enumerator.go b/pkg/remote/aws/lambda_function_enumerator.go deleted file mode 100644 index cf854c65..00000000 --- a/pkg/remote/aws/lambda_function_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type LambdaFunctionEnumerator struct { - repository repository.LambdaRepository - factory resource.ResourceFactory -} - -func NewLambdaFunctionEnumerator(repo repository.LambdaRepository, factory resource.ResourceFactory) *LambdaFunctionEnumerator { - return &LambdaFunctionEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *LambdaFunctionEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsLambdaFunctionResourceType -} - -func (e *LambdaFunctionEnumerator) Enumerate() ([]*resource.Resource, error) { - functions, err := e.repository.ListAllLambdaFunctions() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(functions)) - - for _, function := range functions { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *function.FunctionName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/launch_configuration_enumerator.go b/pkg/remote/aws/launch_configuration_enumerator.go deleted file mode 100644 index 940b8356..00000000 --- a/pkg/remote/aws/launch_configuration_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type LaunchConfigurationEnumerator struct { - repository repository.AutoScalingRepository - factory resource.ResourceFactory -} - -func NewLaunchConfigurationEnumerator(repo repository.AutoScalingRepository, factory resource.ResourceFactory) *LaunchConfigurationEnumerator { - return &LaunchConfigurationEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *LaunchConfigurationEnumerator) SupportedType() resource.ResourceType { - return aws.AwsLaunchConfigurationResourceType -} - -func (e *LaunchConfigurationEnumerator) Enumerate() ([]*resource.Resource, error) { - configs, err := e.repository.DescribeLaunchConfigurations() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(configs)) - - for _, config := range configs { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *config.LaunchConfigurationName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/launch_template_enumerator.go b/pkg/remote/aws/launch_template_enumerator.go deleted file mode 100644 index 0c09c667..00000000 --- a/pkg/remote/aws/launch_template_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type LaunchTemplateEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewLaunchTemplateEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *LaunchTemplateEnumerator { - return &LaunchTemplateEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *LaunchTemplateEnumerator) SupportedType() resource.ResourceType { - return aws.AwsLaunchTemplateResourceType -} - -func (e *LaunchTemplateEnumerator) Enumerate() ([]*resource.Resource, error) { - templates, err := e.repository.DescribeLaunchTemplates() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(templates)) - - for _, tmpl := range templates { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *tmpl.LaunchTemplateId, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/load_balancer_enumerator.go b/pkg/remote/aws/load_balancer_enumerator.go deleted file mode 100644 index 7d87da40..00000000 --- a/pkg/remote/aws/load_balancer_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type LoadBalancerEnumerator struct { - repository repository.ELBV2Repository - factory resource.ResourceFactory -} - -func NewLoadBalancerEnumerator(repo repository.ELBV2Repository, factory resource.ResourceFactory) *LoadBalancerEnumerator { - return &LoadBalancerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *LoadBalancerEnumerator) SupportedType() resource.ResourceType { - return aws.AwsLoadBalancerResourceType -} - -func (e *LoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) { - loadBalancers, err := e.repository.ListAllLoadBalancers() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(loadBalancers)) - - for _, lb := range loadBalancers { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *lb.LoadBalancerArn, - map[string]interface{}{ - "name": *lb.LoadBalancerName, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/load_balancer_listener_enumerator.go b/pkg/remote/aws/load_balancer_listener_enumerator.go deleted file mode 100644 index 2cc518aa..00000000 --- a/pkg/remote/aws/load_balancer_listener_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type LoadBalancerListenerEnumerator struct { - repository repository.ELBV2Repository - factory resource.ResourceFactory -} - -func NewLoadBalancerListenerEnumerator(repo repository.ELBV2Repository, factory resource.ResourceFactory) *LoadBalancerListenerEnumerator { - return &LoadBalancerListenerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *LoadBalancerListenerEnumerator) SupportedType() resource.ResourceType { - return aws.AwsLoadBalancerListenerResourceType -} - -func (e *LoadBalancerListenerEnumerator) Enumerate() ([]*resource.Resource, error) { - loadBalancers, err := e.repository.ListAllLoadBalancers() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsLoadBalancerResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, lb := range loadBalancers { - listeners, err := e.repository.ListAllLoadBalancerListeners(*lb.LoadBalancerArn) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, listener := range listeners { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *listener.ListenerArn, - map[string]interface{}{}, - ), - ) - } - } - - return results, nil -} diff --git a/pkg/remote/aws/provider.go b/pkg/remote/aws/provider.go deleted file mode 100644 index 6ccb0e95..00000000 --- a/pkg/remote/aws/provider.go +++ /dev/null @@ -1,119 +0,0 @@ -package aws - -import ( - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/terraform" - tf "github.com/snyk/driftctl/pkg/terraform" -) - -type awsConfig struct { - AccessKey string - SecretKey string - CredsFilename string - Profile string - Token string - Region string `cty:"region"` - MaxRetries int - - AssumeRoleARN string - AssumeRoleExternalID string - AssumeRoleSessionName string - AssumeRolePolicy string - - AllowedAccountIds []string - ForbiddenAccountIds []string - - Endpoints map[string]string - IgnoreTagsConfig map[string]string - Insecure bool - - SkipCredsValidation bool `cty:"skip_credentials_validation"` - SkipGetEC2Platforms bool - SkipRegionValidation bool - SkipRequestingAccountId bool `cty:"skip_requesting_account_id"` - SkipMetadataApiCheck bool - S3ForcePathStyle bool -} - -type AWSTerraformProvider struct { - *terraform.TerraformProvider - session *session.Session - name string - version string -} - -func NewAWSTerraformProvider(version string, progress output.Progress, configDir string) (*AWSTerraformProvider, error) { - if version == "" { - version = "3.19.0" - } - p := &AWSTerraformProvider{ - version: version, - name: "aws", - } - installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ - Key: p.name, - Version: version, - ConfigDir: configDir, - }) - if err != nil { - return nil, err - } - p.session = session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ - Name: p.name, - DefaultAlias: *p.session.Config.Region, - GetProviderConfig: func(alias string) interface{} { - return awsConfig{ - Region: alias, - // Those two parameters are used to make sure that the credentials are not validated when calling - // Configure(). Credentials validation is now handled directly in driftctl - SkipCredsValidation: true, - SkipRequestingAccountId: true, - - MaxRetries: 10, // TODO make this configurable - } - }, - }, progress) - if err != nil { - return nil, err - } - p.TerraformProvider = tfProvider - return p, err -} - -func (a *AWSTerraformProvider) Name() string { - return a.name -} - -func (p *AWSTerraformProvider) Version() string { - return p.version -} - -func (p *AWSTerraformProvider) CheckCredentialsExist() error { - _, err := p.session.Config.Credentials.Get() - if err == credentials.ErrNoValidProvidersFoundInChain { - return errors.New("Could not find a way to authenticate on AWS!\n" + - "Please refer to AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n\n" + - "To use a different cloud provider, use --to=\"gcp+tf\" for GCP or --to=\"azure+tf\" for Azure.") - } - if err != nil { - return err - } - // This call is to make sure that the credentials are valid - // A more complex logic exist in terraform provider, but it's probably not worth to implement it - // https://github.com/hashicorp/terraform-provider-aws/blob/e3959651092864925045a6044961a73137095798/aws/auth_helpers.go#L111 - _, err = sts.New(p.session).GetCallerIdentity(&sts.GetCallerIdentityInput{}) - if err != nil { - logrus.Debug(err) - return errors.New("Could not authenticate successfully on AWS with the provided credentials.\n" + - "Please refer to the AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n") - } - return nil -} diff --git a/pkg/remote/aws/rds_cluster_enumerator.go b/pkg/remote/aws/rds_cluster_enumerator.go deleted file mode 100644 index eb51ba65..00000000 --- a/pkg/remote/aws/rds_cluster_enumerator.go +++ /dev/null @@ -1,55 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type RDSClusterEnumerator struct { - repository repository.RDSRepository - factory resource.ResourceFactory -} - -func NewRDSClusterEnumerator(repository repository.RDSRepository, factory resource.ResourceFactory) *RDSClusterEnumerator { - return &RDSClusterEnumerator{ - repository, - factory, - } -} - -func (e *RDSClusterEnumerator) SupportedType() resource.ResourceType { - return aws.AwsRDSClusterResourceType -} - -func (e *RDSClusterEnumerator) Enumerate() ([]*resource.Resource, error) { - clusters, err := e.repository.ListAllDBClusters() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(clusters)) - - for _, cluster := range clusters { - var databaseName string - - if v := cluster.DatabaseName; v != nil { - databaseName = *cluster.DatabaseName - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *cluster.DBClusterIdentifier, - map[string]interface{}{ - "cluster_identifier": *cluster.DBClusterIdentifier, - "database_name": databaseName, - }, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/rds_db_instance_enumerator.go b/pkg/remote/aws/rds_db_instance_enumerator.go deleted file mode 100644 index 5e35f1cf..00000000 --- a/pkg/remote/aws/rds_db_instance_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type RDSDBInstanceEnumerator struct { - repository repository.RDSRepository - factory resource.ResourceFactory -} - -func NewRDSDBInstanceEnumerator(repo repository.RDSRepository, factory resource.ResourceFactory) *RDSDBInstanceEnumerator { - return &RDSDBInstanceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *RDSDBInstanceEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDbInstanceResourceType -} - -func (e *RDSDBInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { - instances, err := e.repository.ListAllDBInstances() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(instances)) - - for _, instance := range instances { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *instance.DBInstanceIdentifier, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/rds_db_subnet_group_enumerator.go b/pkg/remote/aws/rds_db_subnet_group_enumerator.go deleted file mode 100644 index 4992f09b..00000000 --- a/pkg/remote/aws/rds_db_subnet_group_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type RDSDBSubnetGroupEnumerator struct { - repository repository.RDSRepository - factory resource.ResourceFactory -} - -func NewRDSDBSubnetGroupEnumerator(repo repository.RDSRepository, factory resource.ResourceFactory) *RDSDBSubnetGroupEnumerator { - return &RDSDBSubnetGroupEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *RDSDBSubnetGroupEnumerator) SupportedType() resource.ResourceType { - return aws.AwsDbSubnetGroupResourceType -} - -func (e *RDSDBSubnetGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - subnetGroups, err := e.repository.ListAllDBSubnetGroups() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(subnetGroups)) - - for _, subnetGroup := range subnetGroups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *subnetGroup.DBSubnetGroupName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/repository/api_gateway_repository.go b/pkg/remote/aws/repository/api_gateway_repository.go deleted file mode 100644 index 4cf88ff5..00000000 --- a/pkg/remote/aws/repository/api_gateway_repository.go +++ /dev/null @@ -1,285 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ApiGatewayRepository interface { - ListAllRestApis() ([]*apigateway.RestApi, error) - GetAccount() (*apigateway.Account, error) - ListAllApiKeys() ([]*apigateway.ApiKey, error) - ListAllRestApiAuthorizers(string) ([]*apigateway.Authorizer, error) - ListAllRestApiStages(string) ([]*apigateway.Stage, error) - ListAllRestApiResources(string) ([]*apigateway.Resource, error) - ListAllDomainNames() ([]*apigateway.DomainName, error) - ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOutput, error) - ListAllRestApiRequestValidators(string) ([]*apigateway.UpdateRequestValidatorOutput, error) - ListAllDomainNameBasePathMappings(string) ([]*apigateway.BasePathMapping, error) - ListAllRestApiModels(string) ([]*apigateway.Model, error) - ListAllRestApiGatewayResponses(string) ([]*apigateway.UpdateGatewayResponseOutput, error) -} - -type apigatewayRepository struct { - client apigatewayiface.APIGatewayAPI - cache cache.Cache -} - -func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository { - return &apigatewayRepository{ - apigateway.New(session), - c, - } -} - -func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) { - cacheKey := "apigatewayListAllRestApis" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*apigateway.RestApi), nil - } - - var restApis []*apigateway.RestApi - input := apigateway.GetRestApisInput{} - err := r.client.GetRestApisPages(&input, - func(resp *apigateway.GetRestApisOutput, lastPage bool) bool { - restApis = append(restApis, resp.Items...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, restApis) - return restApis, nil -} - -func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) { - if v := r.cache.Get("apigatewayGetAccount"); v != nil { - return v.(*apigateway.Account), nil - } - - account, err := r.client.GetAccount(&apigateway.GetAccountInput{}) - if err != nil { - return nil, err - } - - r.cache.Put("apigatewayGetAccount", account) - return account, nil -} - -func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) { - if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil { - return v.([]*apigateway.ApiKey), nil - } - - var apiKeys []*apigateway.ApiKey - input := apigateway.GetApiKeysInput{} - err := r.client.GetApiKeysPages(&input, - func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool { - apiKeys = append(apiKeys, resp.Items...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put("apigatewayListAllApiKeys", apiKeys) - return apiKeys, nil -} - -func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) { - cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigateway.Authorizer), nil - } - - input := &apigateway.GetAuthorizersInput{ - RestApiId: &apiId, - } - resources, err := r.client.GetAuthorizers(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway.Stage, error) { - cacheKey := fmt.Sprintf("apigatewayListAllRestApiStages_api_%s", apiId) - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*apigateway.Stage), nil - } - - input := &apigateway.GetStagesInput{ - RestApiId: &apiId, - } - resources, err := r.client.GetStages(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Item) - return resources.Item, nil -} - -func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigateway.Resource, error) { - cacheKey := fmt.Sprintf("apigatewayListAllRestApiResources_api_%s", apiId) - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*apigateway.Resource), nil - } - - var resources []*apigateway.Resource - input := &apigateway.GetResourcesInput{ - RestApiId: &apiId, - Embed: []*string{aws.String("methods")}, - } - err := r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool { - resources = append(resources, res.Items...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources) - return resources, nil -} - -func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, error) { - cacheKey := "apigatewayListAllDomainNames" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*apigateway.DomainName), nil - } - - var domainNames []*apigateway.DomainName - input := apigateway.GetDomainNamesInput{} - err := r.client.GetDomainNamesPages(&input, - func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool { - domainNames = append(domainNames, resp.Items...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, domainNames) - return domainNames, nil -} - -func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOutput, error) { - if v := r.cache.Get("apigatewayListAllVpcLinks"); v != nil { - return v.([]*apigateway.UpdateVpcLinkOutput), nil - } - - var vpcLinks []*apigateway.UpdateVpcLinkOutput - input := apigateway.GetVpcLinksInput{} - err := r.client.GetVpcLinksPages(&input, - func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool { - vpcLinks = append(vpcLinks, resp.Items...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put("apigatewayListAllVpcLinks", vpcLinks) - return vpcLinks, nil -} - -func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([]*apigateway.UpdateRequestValidatorOutput, error) { - cacheKey := fmt.Sprintf("apigatewayListAllRestApiRequestValidators_api_%s", apiId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigateway.UpdateRequestValidatorOutput), nil - } - - input := &apigateway.GetRequestValidatorsInput{ - RestApiId: &apiId, - } - resources, err := r.client.GetRequestValidators(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName string) ([]*apigateway.BasePathMapping, error) { - cacheKey := fmt.Sprintf("apigatewayListAllDomainNameBasePathMappings_domainName_%s", domainName) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigateway.BasePathMapping), nil - } - - var mappings []*apigateway.BasePathMapping - input := &apigateway.GetBasePathMappingsInput{ - DomainName: &domainName, - } - err := r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool { - mappings = append(mappings, res.Items...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, mappings) - return mappings, nil -} - -func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway.Model, error) { - cacheKey := fmt.Sprintf("apigatewayListAllRestApiModels_api_%s", apiId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigateway.Model), nil - } - - var resources []*apigateway.Model - input := &apigateway.GetModelsInput{ - RestApiId: &apiId, - } - err := r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool { - resources = append(resources, res.Items...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources) - return resources, nil -} - -func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]*apigateway.UpdateGatewayResponseOutput, error) { - cacheKey := fmt.Sprintf("apigatewayListAllRestApiGatewayResponses_api_%s", apiId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigateway.UpdateGatewayResponseOutput), nil - } - - input := &apigateway.GetGatewayResponsesInput{ - RestApiId: &apiId, - } - resources, err := r.client.GetGatewayResponses(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} diff --git a/pkg/remote/aws/repository/api_gateway_repository_test.go b/pkg/remote/aws/repository/api_gateway_repository_test.go deleted file mode 100644 index 83898b4b..00000000 --- a/pkg/remote/aws/repository/api_gateway_repository_test.go +++ /dev/null @@ -1,890 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_apigatewayRepository_ListAllRestApis(t *testing.T) { - apis := []*apigateway.RestApi{ - {Id: aws.String("restapi1")}, - {Id: aws.String("restapi2")}, - {Id: aws.String("restapi3")}, - {Id: aws.String("restapi4")}, - {Id: aws.String("restapi5")}, - {Id: aws.String("restapi6")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.RestApi - wantErr error - }{ - { - name: "list multiple rest apis", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetRestApisPages", - &apigateway.GetRestApisInput{}, - mock.MatchedBy(func(callback func(res *apigateway.GetRestApisOutput, lastPage bool) bool) bool { - callback(&apigateway.GetRestApisOutput{ - Items: apis[:3], - }, false) - callback(&apigateway.GetRestApisOutput{ - Items: apis[3:], - }, true) - return true - })).Return(nil).Once() - - store.On("GetAndLock", "apigatewayListAllRestApis").Return(nil).Times(1) - store.On("Unlock", "apigatewayListAllRestApis").Times(1) - store.On("Put", "apigatewayListAllRestApis", apis).Return(false).Times(1) - }, - want: apis, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("GetAndLock", "apigatewayListAllRestApis").Return(apis).Times(1) - store.On("Unlock", "apigatewayListAllRestApis").Times(1) - }, - want: apis, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApis() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_GetAccount(t *testing.T) { - account := &apigateway.Account{ - CloudwatchRoleArn: aws.String("arn:aws:iam::017011014111:role/api_gateway_cloudwatch_global"), - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want *apigateway.Account - wantErr error - }{ - { - name: "get a single account", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetAccount", &apigateway.GetAccountInput{}).Return(account, nil).Once() - - store.On("Get", "apigatewayGetAccount").Return(nil).Times(1) - store.On("Put", "apigatewayGetAccount", account).Return(false).Times(1) - }, - want: account, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayGetAccount").Return(account).Times(1) - }, - want: account, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.GetAccount() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllApiKeys(t *testing.T) { - keys := []*apigateway.ApiKey{ - {Id: aws.String("apikey1")}, - {Id: aws.String("apikey2")}, - {Id: aws.String("apikey3")}, - {Id: aws.String("apikey4")}, - {Id: aws.String("apikey5")}, - {Id: aws.String("apikey6")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.ApiKey - wantErr error - }{ - { - name: "list multiple api keys", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetApiKeysPages", - &apigateway.GetApiKeysInput{}, - mock.MatchedBy(func(callback func(res *apigateway.GetApiKeysOutput, lastPage bool) bool) bool { - callback(&apigateway.GetApiKeysOutput{ - Items: keys[:3], - }, false) - callback(&apigateway.GetApiKeysOutput{ - Items: keys[3:], - }, true) - return true - })).Return(nil).Once() - - store.On("Get", "apigatewayListAllApiKeys").Return(nil).Times(1) - store.On("Put", "apigatewayListAllApiKeys", keys).Return(false).Times(1) - }, - want: keys, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllApiKeys").Return(keys).Times(1) - }, - want: keys, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllApiKeys() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllRestApiAuthorizers(t *testing.T) { - api := &apigateway.RestApi{ - Id: aws.String("restapi1"), - } - - apiAuthorizers := []*apigateway.Authorizer{ - {Id: aws.String("resource1")}, - {Id: aws.String("resource2")}, - {Id: aws.String("resource3")}, - {Id: aws.String("resource4")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.Authorizer - wantErr error - }{ - { - name: "list multiple rest api authorizers", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetAuthorizers", - &apigateway.GetAuthorizersInput{ - RestApiId: aws.String("restapi1"), - }).Return(&apigateway.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once() - - store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(nil).Times(1) - store.On("Put", "apigatewayListAllRestApiAuthorizers_api_restapi1", apiAuthorizers).Return(false).Times(1) - }, - want: apiAuthorizers, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(apiAuthorizers).Times(1) - }, - want: apiAuthorizers, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApiAuthorizers(*api.Id) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllRestApiStages(t *testing.T) { - api := &apigateway.RestApi{ - Id: aws.String("restapi1"), - } - - apiStages := []*apigateway.Stage{ - {StageName: aws.String("stage1")}, - {StageName: aws.String("stage2")}, - {StageName: aws.String("stage3")}, - {StageName: aws.String("stage4")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.Stage - wantErr error - }{ - { - name: "list multiple rest api stages", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetStages", - &apigateway.GetStagesInput{ - RestApiId: aws.String("restapi1"), - }).Return(&apigateway.GetStagesOutput{Item: apiStages}, nil).Once() - - store.On("GetAndLock", "apigatewayListAllRestApiStages_api_restapi1").Return(nil).Times(1) - store.On("Unlock", "apigatewayListAllRestApiStages_api_restapi1").Times(1) - store.On("Put", "apigatewayListAllRestApiStages_api_restapi1", apiStages).Return(false).Times(1) - }, - want: apiStages, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("GetAndLock", "apigatewayListAllRestApiStages_api_restapi1").Return(apiStages).Times(1) - store.On("Unlock", "apigatewayListAllRestApiStages_api_restapi1").Times(1) - }, - want: apiStages, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApiStages(*api.Id) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllRestApiResources(t *testing.T) { - api := &apigateway.RestApi{ - Id: aws.String("restapi1"), - } - - apiResources := []*apigateway.Resource{ - {Id: aws.String("resource1")}, - {Id: aws.String("resource2")}, - {Id: aws.String("resource3")}, - {Id: aws.String("resource4")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.Resource - wantErr error - }{ - { - name: "list multiple rest api resources", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetResourcesPages", - &apigateway.GetResourcesInput{ - RestApiId: aws.String("restapi1"), - Embed: []*string{aws.String("methods")}, - }, - mock.MatchedBy(func(callback func(res *apigateway.GetResourcesOutput, lastPage bool) bool) bool { - callback(&apigateway.GetResourcesOutput{ - Items: apiResources, - }, true) - return true - })).Return(nil).Once() - - store.On("GetAndLock", "apigatewayListAllRestApiResources_api_restapi1").Return(nil).Times(1) - store.On("Unlock", "apigatewayListAllRestApiResources_api_restapi1").Times(1) - store.On("Put", "apigatewayListAllRestApiResources_api_restapi1", apiResources).Return(false).Times(1) - }, - want: apiResources, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("GetAndLock", "apigatewayListAllRestApiResources_api_restapi1").Return(apiResources).Times(1) - store.On("Unlock", "apigatewayListAllRestApiResources_api_restapi1").Times(1) - }, - want: apiResources, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApiResources(*api.Id) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllDomainNames(t *testing.T) { - domainNames := []*apigateway.DomainName{ - {DomainName: aws.String("domainName1")}, - {DomainName: aws.String("domainName2")}, - {DomainName: aws.String("domainName3")}, - {DomainName: aws.String("domainName4")}, - {DomainName: aws.String("domainName5")}, - {DomainName: aws.String("domainName6")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.DomainName - wantErr error - }{ - { - name: "list multiple domain names", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetDomainNamesPages", - &apigateway.GetDomainNamesInput{}, - mock.MatchedBy(func(callback func(res *apigateway.GetDomainNamesOutput, lastPage bool) bool) bool { - callback(&apigateway.GetDomainNamesOutput{ - Items: domainNames[:3], - }, false) - callback(&apigateway.GetDomainNamesOutput{ - Items: domainNames[3:], - }, true) - return true - })).Return(nil).Once() - - store.On("GetAndLock", "apigatewayListAllDomainNames").Return(nil).Times(1) - store.On("Unlock", "apigatewayListAllDomainNames").Times(1) - store.On("Put", "apigatewayListAllDomainNames", domainNames).Return(false).Times(1) - }, - want: domainNames, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("GetAndLock", "apigatewayListAllDomainNames").Return(domainNames).Times(1) - store.On("Unlock", "apigatewayListAllDomainNames").Times(1) - }, - want: domainNames, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllDomainNames() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllVpcLinks(t *testing.T) { - vpcLinks := []*apigateway.UpdateVpcLinkOutput{ - {Id: aws.String("vpcLink1")}, - {Id: aws.String("vpcLink2")}, - {Id: aws.String("vpcLink3")}, - {Id: aws.String("vpcLink4")}, - {Id: aws.String("vpcLink5")}, - {Id: aws.String("vpcLink6")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.UpdateVpcLinkOutput - wantErr error - }{ - { - name: "list multiple vpc links", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetVpcLinksPages", - &apigateway.GetVpcLinksInput{}, - mock.MatchedBy(func(callback func(res *apigateway.GetVpcLinksOutput, lastPage bool) bool) bool { - callback(&apigateway.GetVpcLinksOutput{ - Items: vpcLinks[:3], - }, false) - callback(&apigateway.GetVpcLinksOutput{ - Items: vpcLinks[3:], - }, true) - return true - })).Return(nil).Once() - - store.On("Get", "apigatewayListAllVpcLinks").Return(nil).Times(1) - store.On("Put", "apigatewayListAllVpcLinks", vpcLinks).Return(false).Times(1) - }, - want: vpcLinks, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllVpcLinks").Return(vpcLinks).Times(1) - }, - want: vpcLinks, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllVpcLinks() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllRestApiRequestValidators(t *testing.T) { - api := &apigateway.RestApi{ - Id: aws.String("restapi1"), - } - - requestValidators := []*apigateway.UpdateRequestValidatorOutput{ - {Id: aws.String("reqVal1")}, - {Id: aws.String("reqVal2")}, - {Id: aws.String("reqVal3")}, - {Id: aws.String("reqVal4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.UpdateRequestValidatorOutput - wantErr error - }{ - { - name: "list multiple rest api request validators", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetRequestValidators", - &apigateway.GetRequestValidatorsInput{ - RestApiId: aws.String("restapi1"), - }).Return(&apigateway.GetRequestValidatorsOutput{Items: requestValidators}, nil).Once() - - store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(nil).Times(1) - store.On("Put", "apigatewayListAllRestApiRequestValidators_api_restapi1", requestValidators).Return(false).Times(1) - }, - want: requestValidators, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(requestValidators).Times(1) - }, - want: requestValidators, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetRequestValidators", - &apigateway.GetRequestValidatorsInput{ - RestApiId: aws.String("restapi1"), - }).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApiRequestValidators(*api.Id) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllDomainNameBasePathMappings(t *testing.T) { - domainName := &apigateway.DomainName{ - DomainName: aws.String("domainName1"), - } - - mappings := []*apigateway.BasePathMapping{ - {BasePath: aws.String("path1")}, - {BasePath: aws.String("path2")}, - {BasePath: aws.String("path3")}, - {BasePath: aws.String("path4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.BasePathMapping - wantErr error - }{ - { - name: "list multiple domain name base path mappings", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetBasePathMappingsPages", - &apigateway.GetBasePathMappingsInput{ - DomainName: aws.String("domainName1"), - }, - mock.MatchedBy(func(callback func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool) bool { - callback(&apigateway.GetBasePathMappingsOutput{ - Items: mappings, - }, true) - return true - })).Return(nil).Once() - - store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(nil).Times(1) - store.On("Put", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1", mappings).Return(false).Times(1) - }, - want: mappings, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(mappings).Times(1) - }, - want: mappings, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetBasePathMappingsPages", - &apigateway.GetBasePathMappingsInput{ - DomainName: aws.String("domainName1"), - }, mock.AnythingOfType("func(*apigateway.GetBasePathMappingsOutput, bool) bool")).Return(remoteError).Once() - - store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllDomainNameBasePathMappings(*domainName.DomainName) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllRestApiModels(t *testing.T) { - api := &apigateway.RestApi{ - Id: aws.String("restapi1"), - } - - apiModels := []*apigateway.Model{ - {Id: aws.String("model1")}, - {Id: aws.String("model2")}, - {Id: aws.String("model3")}, - {Id: aws.String("model4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.Model - wantErr error - }{ - { - name: "list multiple rest api models", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetModelsPages", - &apigateway.GetModelsInput{ - RestApiId: aws.String("restapi1"), - }, - mock.MatchedBy(func(callback func(res *apigateway.GetModelsOutput, lastPage bool) bool) bool { - callback(&apigateway.GetModelsOutput{ - Items: apiModels, - }, true) - return true - })).Return(nil).Once() - - store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(nil).Times(1) - store.On("Put", "apigatewayListAllRestApiModels_api_restapi1", apiModels).Return(false).Times(1) - }, - want: apiModels, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(apiModels).Times(1) - }, - want: apiModels, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetModelsPages", - &apigateway.GetModelsInput{ - RestApiId: aws.String("restapi1"), - }, mock.AnythingOfType("func(*apigateway.GetModelsOutput, bool) bool")).Return(remoteError).Once() - - store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApiModels(*api.Id) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayRepository_ListAllRestApiGatewayResponses(t *testing.T) { - api := &apigateway.RestApi{ - Id: aws.String("restapi1"), - } - - gtwResponses := []*apigateway.UpdateGatewayResponseOutput{ - {ResponseType: aws.String("ACCESS_DENIED")}, - {ResponseType: aws.String("DEFAULT_4XX")}, - {ResponseType: aws.String("DEFAULT_5XX")}, - {ResponseType: aws.String("UNAUTHORIZED")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache) - want []*apigateway.UpdateGatewayResponseOutput - wantErr error - }{ - { - name: "list multiple rest api gateway responses", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetGatewayResponses", - &apigateway.GetGatewayResponsesInput{ - RestApiId: aws.String("restapi1"), - }).Return(&apigateway.GetGatewayResponsesOutput{Items: gtwResponses}, nil).Once() - - store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(nil).Times(1) - store.On("Put", "apigatewayListAllRestApiGatewayResponses_api_restapi1", gtwResponses).Return(false).Times(1) - }, - want: gtwResponses, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(gtwResponses).Times(1) - }, - want: gtwResponses, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) { - client.On("GetGatewayResponses", - &apigateway.GetGatewayResponsesInput{ - RestApiId: aws.String("restapi1"), - }).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGateway{} - tt.mocks(client, store) - r := &apigatewayRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRestApiGatewayResponses(*api.Id) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/apigatewayv2_repository.go b/pkg/remote/aws/repository/apigatewayv2_repository.go deleted file mode 100644 index ad528682..00000000 --- a/pkg/remote/aws/repository/apigatewayv2_repository.go +++ /dev/null @@ -1,228 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/apigatewayv2" - "github.com/aws/aws-sdk-go/service/apigatewayv2/apigatewayv2iface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ApiGatewayV2Repository interface { - ListAllApis() ([]*apigatewayv2.Api, error) - ListAllApiRoutes(apiId *string) ([]*apigatewayv2.Route, error) - ListAllApiDeployments(apiId *string) ([]*apigatewayv2.Deployment, error) - ListAllVpcLinks() ([]*apigatewayv2.VpcLink, error) - ListAllApiAuthorizers(string) ([]*apigatewayv2.Authorizer, error) - ListAllApiIntegrations(string) ([]*apigatewayv2.Integration, error) - ListAllApiModels(string) ([]*apigatewayv2.Model, error) - ListAllApiStages(string) ([]*apigatewayv2.Stage, error) - ListAllApiRouteResponses(string, string) ([]*apigatewayv2.RouteResponse, error) - ListAllApiMappings(string) ([]*apigatewayv2.ApiMapping, error) - ListAllApiIntegrationResponses(string, string) ([]*apigatewayv2.IntegrationResponse, error) -} -type apigatewayv2Repository struct { - client apigatewayv2iface.ApiGatewayV2API - cache cache.Cache -} - -func NewApiGatewayV2Repository(session *session.Session, c cache.Cache) *apigatewayv2Repository { - return &apigatewayv2Repository{ - apigatewayv2.New(session), - c, - } -} - -func (r *apigatewayv2Repository) ListAllApis() ([]*apigatewayv2.Api, error) { - cacheKey := "apigatewayv2ListAllApis" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*apigatewayv2.Api), nil - } - - input := apigatewayv2.GetApisInput{} - resources, err := r.client.GetApis(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiRoutes(apiID *string) ([]*apigatewayv2.Route, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiRoutes_api_%s", *apiID) - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*apigatewayv2.Route), nil - } - - resources, err := r.client.GetRoutes(&apigatewayv2.GetRoutesInput{ApiId: apiID}) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiDeployments(apiID *string) ([]*apigatewayv2.Deployment, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiDeployments_api_%s", *apiID) - v := r.cache.Get(cacheKey) - - if v != nil { - return v.([]*apigatewayv2.Deployment), nil - } - - resources, err := r.client.GetDeployments(&apigatewayv2.GetDeploymentsInput{ApiId: apiID}) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllVpcLinks() ([]*apigatewayv2.VpcLink, error) { - if v := r.cache.Get("apigatewayv2ListAllVpcLinks"); v != nil { - return v.([]*apigatewayv2.VpcLink), nil - } - - input := apigatewayv2.GetVpcLinksInput{} - resources, err := r.client.GetVpcLinks(&input) - if err != nil { - return nil, err - } - - r.cache.Put("apigatewayv2ListAllVpcLinks", resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiAuthorizers(apiId string) ([]*apigatewayv2.Authorizer, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiAuthorizers_api_%s", apiId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigatewayv2.Authorizer), nil - } - - input := apigatewayv2.GetAuthorizersInput{ - ApiId: &apiId, - } - resources, err := r.client.GetAuthorizers(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiIntegrations(apiId string) ([]*apigatewayv2.Integration, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiIntegrations_api_%s", apiId) - - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigatewayv2.Integration), nil - } - - input := apigatewayv2.GetIntegrationsInput{ - ApiId: &apiId, - } - resources, err := r.client.GetIntegrations(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiModels(apiId string) ([]*apigatewayv2.Model, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiModels_api_%s", apiId) - - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigatewayv2.Model), nil - } - - input := apigatewayv2.GetModelsInput{ - ApiId: &apiId, - } - resources, err := r.client.GetModels(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiStages(apiId string) ([]*apigatewayv2.Stage, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiStages_api_%s", apiId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigatewayv2.Stage), nil - } - - input := apigatewayv2.GetStagesInput{ - ApiId: &apiId, - } - resources, err := r.client.GetStages(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiIntegrationResponses(apiId, integrationId string) ([]*apigatewayv2.IntegrationResponse, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiIntegrationResponses_api_%s_integration_%s", apiId, integrationId) - v := r.cache.Get(cacheKey) - if v != nil { - return v.([]*apigatewayv2.IntegrationResponse), nil - } - input := apigatewayv2.GetIntegrationResponsesInput{ - ApiId: &apiId, - IntegrationId: &integrationId, - } - resources, err := r.client.GetIntegrationResponses(&input) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiRouteResponses(apiId, routeId string) ([]*apigatewayv2.RouteResponse, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiRouteResponses_api_%s_route_%s", apiId, routeId) - v := r.cache.Get(cacheKey) - if v != nil { - return v.([]*apigatewayv2.RouteResponse), nil - } - input := apigatewayv2.GetRouteResponsesInput{ - ApiId: &apiId, - RouteId: &routeId, - } - resources, err := r.client.GetRouteResponses(&input) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} - -func (r *apigatewayv2Repository) ListAllApiMappings(domainName string) ([]*apigatewayv2.ApiMapping, error) { - cacheKey := fmt.Sprintf("apigatewayv2ListAllApiMappings_api_%s", domainName) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*apigatewayv2.ApiMapping), nil - } - - input := apigatewayv2.GetApiMappingsInput{ - DomainName: &domainName, - } - resources, err := r.client.GetApiMappings(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources.Items) - return resources.Items, nil -} diff --git a/pkg/remote/aws/repository/apigatewayv2_repository_test.go b/pkg/remote/aws/repository/apigatewayv2_repository_test.go deleted file mode 100644 index f459ed03..00000000 --- a/pkg/remote/aws/repository/apigatewayv2_repository_test.go +++ /dev/null @@ -1,637 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/apigatewayv2" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_apigatewayv2Repository_ListAllApis(t *testing.T) { - apis := []*apigatewayv2.Api{ - {ApiId: aws.String("api1")}, - {ApiId: aws.String("api2")}, - {ApiId: aws.String("api3")}, - {ApiId: aws.String("api4")}, - {ApiId: aws.String("api5")}, - {ApiId: aws.String("api6")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.Api - wantErr error - }{ - { - name: "list multiple apis", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetApis", - &apigatewayv2.GetApisInput{}).Return(&apigatewayv2.GetApisOutput{Items: apis}, nil).Once() - - store.On("GetAndLock", "apigatewayv2ListAllApis").Return(nil).Times(1) - store.On("Unlock", "apigatewayv2ListAllApis").Times(1) - store.On("Put", "apigatewayv2ListAllApis", apis).Return(false).Times(1) - }, - want: apis, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("GetAndLock", "apigatewayv2ListAllApis").Return(apis).Times(1) - store.On("Unlock", "apigatewayv2ListAllApis").Times(1) - }, - want: apis, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetApis", - &apigatewayv2.GetApisInput{}).Return(nil, remoteError).Once() - - store.On("GetAndLock", "apigatewayv2ListAllApis").Return(nil).Times(1) - store.On("Unlock", "apigatewayv2ListAllApis").Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApis() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllApiRoutes(t *testing.T) { - routes := []*apigatewayv2.Route{ - {RouteId: aws.String("route1")}, - {RouteId: aws.String("route2")}, - {RouteId: aws.String("route3")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.Route - wantErr error - }{ - { - name: "list multiple routes", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetRoutes", - &apigatewayv2.GetRoutesInput{ApiId: aws.String("an-id")}). - Return(&apigatewayv2.GetRoutesOutput{Items: routes}, nil).Once() - - store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(nil).Times(1) - store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1) - store.On("Put", "apigatewayv2ListAllApiRoutes_api_an-id", routes).Return(false).Times(1) - }, - want: routes, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(routes).Times(1) - store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1) - }, - want: routes, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetRoutes", - &apigatewayv2.GetRoutesInput{ApiId: aws.String("an-id")}).Return(nil, remoteError).Once() - - store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(nil).Times(1) - store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApiRoutes(aws.String("an-id")) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllApiDeployments(t *testing.T) { - deployments := []*apigatewayv2.Deployment{ - {DeploymentId: aws.String("id1")}, - {DeploymentId: aws.String("id2")}, - {DeploymentId: aws.String("id3")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.Deployment - wantErr error - }{ - { - name: "list multiple deployments", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetDeployments", - &apigatewayv2.GetDeploymentsInput{ApiId: aws.String("an-id")}). - Return(&apigatewayv2.GetDeploymentsOutput{Items: deployments}, nil).Once() - - store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(nil).Times(1) - store.On("Put", "apigatewayv2ListAllApiDeployments_api_an-id", deployments).Return(false).Times(1) - }, - want: deployments, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(deployments).Times(1) - }, - want: deployments, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetDeployments", - &apigatewayv2.GetDeploymentsInput{ApiId: aws.String("an-id")}).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApiDeployments(aws.String("an-id")) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllVpcLinks(t *testing.T) { - vpcLinks := []*apigatewayv2.VpcLink{ - {VpcLinkId: aws.String("vpcLink1")}, - {VpcLinkId: aws.String("vpcLink2")}, - {VpcLinkId: aws.String("vpcLink3")}, - {VpcLinkId: aws.String("vpcLink4")}, - {VpcLinkId: aws.String("vpcLink5")}, - {VpcLinkId: aws.String("vpcLink6")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.VpcLink - wantErr error - }{ - { - name: "list multiple vpc links", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetVpcLinks", - &apigatewayv2.GetVpcLinksInput{}).Return(&apigatewayv2.GetVpcLinksOutput{Items: vpcLinks}, nil).Once() - - store.On("Get", "apigatewayv2ListAllVpcLinks").Return(nil).Times(1) - store.On("Put", "apigatewayv2ListAllVpcLinks", vpcLinks).Return(false).Times(1) - }, - want: vpcLinks, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("Get", "apigatewayv2ListAllVpcLinks").Return(vpcLinks).Times(1) - }, - want: vpcLinks, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetVpcLinks", - &apigatewayv2.GetVpcLinksInput{}).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayv2ListAllVpcLinks").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllVpcLinks() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllApiAuthorizers(t *testing.T) { - api := &apigatewayv2.Api{ - ApiId: aws.String("api1"), - } - - apiAuthorizers := []*apigatewayv2.Authorizer{ - {AuthorizerId: aws.String("authorizer1")}, - {AuthorizerId: aws.String("authorizer2")}, - {AuthorizerId: aws.String("authorizer3")}, - {AuthorizerId: aws.String("authorizer4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.Authorizer - wantErr error - }{ - { - name: "list multiple api authorizers", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetAuthorizers", - &apigatewayv2.GetAuthorizersInput{ - ApiId: aws.String("api1"), - }).Return(&apigatewayv2.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once() - - store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(nil).Times(1) - store.On("Put", "apigatewayv2ListAllApiAuthorizers_api_api1", apiAuthorizers).Return(false).Times(1) - }, - want: apiAuthorizers, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(apiAuthorizers).Times(1) - }, - want: apiAuthorizers, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetAuthorizers", - &apigatewayv2.GetAuthorizersInput{ - ApiId: aws.String("api1"), - }).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApiAuthorizers(*api.ApiId) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllApiIntegrations(t *testing.T) { - api := &apigatewayv2.Api{ - ApiId: aws.String("api1"), - } - - apiIntegrations := []*apigatewayv2.Integration{ - {IntegrationId: aws.String("integration1")}, - {IntegrationId: aws.String("integration2")}, - {IntegrationId: aws.String("integration3")}, - {IntegrationId: aws.String("integration4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.Integration - wantErr error - }{ - { - name: "list multiple api integrations", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetIntegrations", - &apigatewayv2.GetIntegrationsInput{ - ApiId: aws.String("api1"), - }).Return(&apigatewayv2.GetIntegrationsOutput{Items: apiIntegrations}, nil).Once() - - store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(nil).Times(1) - store.On("Put", "apigatewayv2ListAllApiIntegrations_api_api1", apiIntegrations).Return(false).Times(1) - }, - want: apiIntegrations, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(apiIntegrations).Times(1) - }, - want: apiIntegrations, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetIntegrations", - &apigatewayv2.GetIntegrationsInput{ - ApiId: aws.String("api1"), - }).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApiIntegrations(*api.ApiId) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllApiRouteResponses(t *testing.T) { - api := &apigatewayv2.Api{ - ApiId: aws.String("api1"), - } - - route := &apigatewayv2.Route{ - RouteId: aws.String("route1"), - } - - responses := []*apigatewayv2.RouteResponse{ - {RouteResponseId: aws.String("response1")}, - {RouteResponseId: aws.String("response2")}, - {RouteResponseId: aws.String("response3")}, - {RouteResponseId: aws.String("response4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.RouteResponse - wantErr error - }{ - { - name: "list multiple api route responses", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetRouteResponses", - &apigatewayv2.GetRouteResponsesInput{ - ApiId: aws.String("api1"), - RouteId: aws.String("route1"), - }).Return(&apigatewayv2.GetRouteResponsesOutput{Items: responses}, nil).Once() - - store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(nil).Times(1) - store.On("Put", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1", responses).Return(false).Times(1) - }, - want: responses, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(responses).Times(1) - }, - want: responses, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetRouteResponses", - &apigatewayv2.GetRouteResponsesInput{ - ApiId: aws.String("api1"), - RouteId: aws.String("route1"), - }).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApiRouteResponses(*api.ApiId, *route.RouteId) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} - -func Test_apigatewayv2Repository_ListAllApiIntegrationResponses(t *testing.T) { - api := &apigatewayv2.Api{ - ApiId: aws.String("api1"), - } - - integration := &apigatewayv2.Integration{ - IntegrationId: aws.String("integration1"), - } - - responses := []*apigatewayv2.IntegrationResponse{ - {IntegrationResponseId: aws.String("response1")}, - {IntegrationResponseId: aws.String("response2")}, - {IntegrationResponseId: aws.String("response3")}, - {IntegrationResponseId: aws.String("response4")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) - want []*apigatewayv2.IntegrationResponse - wantErr error - }{ - { - name: "list multiple api integration responses", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetIntegrationResponses", - &apigatewayv2.GetIntegrationResponsesInput{ - ApiId: aws.String("api1"), - IntegrationId: aws.String("integration1"), - }).Return(&apigatewayv2.GetIntegrationResponsesOutput{Items: responses}, nil).Once() - - store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(nil).Times(1) - store.On("Put", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1", responses).Return(false).Times(1) - }, - want: responses, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(responses).Times(1) - }, - want: responses, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) { - client.On("GetIntegrationResponses", - &apigatewayv2.GetIntegrationResponsesInput{ - ApiId: aws.String("api1"), - IntegrationId: aws.String("integration1"), - }).Return(nil, remoteError).Once() - - store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApiGatewayV2{} - tt.mocks(client, store) - r := &apigatewayv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllApiIntegrationResponses(*api.ApiId, *integration.IntegrationId) - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/appautoscaling_repository.go b/pkg/remote/aws/repository/appautoscaling_repository.go deleted file mode 100644 index 47fdaea8..00000000 --- a/pkg/remote/aws/repository/appautoscaling_repository.go +++ /dev/null @@ -1,87 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/applicationautoscaling" - "github.com/aws/aws-sdk-go/service/applicationautoscaling/applicationautoscalingiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type AppAutoScalingRepository interface { - ServiceNamespaceValues() []string - DescribeScalableTargets(string) ([]*applicationautoscaling.ScalableTarget, error) - DescribeScalingPolicies(string) ([]*applicationautoscaling.ScalingPolicy, error) - DescribeScheduledActions(string) ([]*applicationautoscaling.ScheduledAction, error) -} - -type appAutoScalingRepository struct { - client applicationautoscalingiface.ApplicationAutoScalingAPI - cache cache.Cache -} - -func NewAppAutoScalingRepository(session *session.Session, c cache.Cache) *appAutoScalingRepository { - return &appAutoScalingRepository{ - applicationautoscaling.New(session), - c, - } -} - -func (r *appAutoScalingRepository) ServiceNamespaceValues() []string { - return applicationautoscaling.ServiceNamespace_Values() -} - -func (r *appAutoScalingRepository) DescribeScalableTargets(namespace string) ([]*applicationautoscaling.ScalableTarget, error) { - cacheKey := fmt.Sprintf("appAutoScalingDescribeScalableTargets_%s", namespace) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*applicationautoscaling.ScalableTarget), nil - } - - input := &applicationautoscaling.DescribeScalableTargetsInput{ - ServiceNamespace: &namespace, - } - result, err := r.client.DescribeScalableTargets(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, result.ScalableTargets) - return result.ScalableTargets, nil -} - -func (r *appAutoScalingRepository) DescribeScalingPolicies(namespace string) ([]*applicationautoscaling.ScalingPolicy, error) { - cacheKey := fmt.Sprintf("appAutoScalingDescribeScalingPolicies_%s", namespace) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*applicationautoscaling.ScalingPolicy), nil - } - - input := &applicationautoscaling.DescribeScalingPoliciesInput{ - ServiceNamespace: &namespace, - } - result, err := r.client.DescribeScalingPolicies(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, result.ScalingPolicies) - return result.ScalingPolicies, nil -} - -func (r *appAutoScalingRepository) DescribeScheduledActions(namespace string) ([]*applicationautoscaling.ScheduledAction, error) { - cacheKey := fmt.Sprintf("appAutoScalingDescribeScheduledActions_%s", namespace) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*applicationautoscaling.ScheduledAction), nil - } - - input := &applicationautoscaling.DescribeScheduledActionsInput{ - ServiceNamespace: &namespace, - } - result, err := r.client.DescribeScheduledActions(input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, result.ScheduledActions) - return result.ScheduledActions, nil -} diff --git a/pkg/remote/aws/repository/appautoscaling_repository_test.go b/pkg/remote/aws/repository/appautoscaling_repository_test.go deleted file mode 100644 index 88d87138..00000000 --- a/pkg/remote/aws/repository/appautoscaling_repository_test.go +++ /dev/null @@ -1,342 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/applicationautoscaling" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_appautoscalingRepository_DescribeScalableTargets(t *testing.T) { - type args struct { - namespace string - } - - tests := []struct { - name string - args args - mocks func(*awstest.MockFakeApplicationAutoScaling, *cache.MockCache) - want []*applicationautoscaling.ScalableTarget - wantErr error - }{ - { - name: "should return remote error", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - client.On("DescribeScalableTargets", - &applicationautoscaling.DescribeScalableTargetsInput{ - ServiceNamespace: aws.String("test"), - }).Return(nil, errors.New("remote error")).Once() - - c.On("Get", "appAutoScalingDescribeScalableTargets_test").Return(nil).Once() - }, - want: nil, - wantErr: errors.New("remote error"), - }, - { - name: "should return scalable targets", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - results := []*applicationautoscaling.ScalableTarget{ - { - RoleARN: aws.String("test_target"), - }, - } - - client.On("DescribeScalableTargets", - &applicationautoscaling.DescribeScalableTargetsInput{ - ServiceNamespace: aws.String("test"), - }).Return(&applicationautoscaling.DescribeScalableTargetsOutput{ - ScalableTargets: results, - }, nil).Once() - - c.On("Get", "appAutoScalingDescribeScalableTargets_test").Return(nil).Once() - c.On("Put", "appAutoScalingDescribeScalableTargets_test", results).Return(true).Once() - }, - want: []*applicationautoscaling.ScalableTarget{ - { - RoleARN: aws.String("test_target"), - }, - }, - }, - { - name: "should hit cache return scalable targets", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - results := []*applicationautoscaling.ScalableTarget{ - { - RoleARN: aws.String("test_target"), - }, - } - - c.On("Get", "appAutoScalingDescribeScalableTargets_test").Return(results).Once() - }, - want: []*applicationautoscaling.ScalableTarget{ - { - RoleARN: aws.String("test_target"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApplicationAutoScaling{} - tt.mocks(client, store) - - r := &appAutoScalingRepository{ - client: client, - cache: store, - } - got, err := r.DescribeScalableTargets(tt.args.namespace) - if err != nil { - assert.EqualError(t, tt.wantErr, err.Error()) - } else { - assert.Equal(t, tt.wantErr, err) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - - client.AssertExpectations(t) - store.AssertExpectations(t) - }) - } -} - -func Test_appautoscalingRepository_DescribeScalingPolicies(t *testing.T) { - type args struct { - namespace string - } - - tests := []struct { - name string - args args - mocks func(*awstest.MockFakeApplicationAutoScaling, *cache.MockCache) - want []*applicationautoscaling.ScalingPolicy - wantErr error - }{ - { - name: "should return remote error", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - client.On("DescribeScalingPolicies", - &applicationautoscaling.DescribeScalingPoliciesInput{ - ServiceNamespace: aws.String("test"), - }).Return(nil, errors.New("remote error")).Once() - - c.On("Get", "appAutoScalingDescribeScalingPolicies_test").Return(nil).Once() - }, - want: nil, - wantErr: errors.New("remote error"), - }, - { - name: "should return scaling policies", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - results := []*applicationautoscaling.ScalingPolicy{ - { - PolicyARN: aws.String("test_policy"), - }, - } - - client.On("DescribeScalingPolicies", - &applicationautoscaling.DescribeScalingPoliciesInput{ - ServiceNamespace: aws.String("test"), - }).Return(&applicationautoscaling.DescribeScalingPoliciesOutput{ - ScalingPolicies: results, - }, nil).Once() - - c.On("Get", "appAutoScalingDescribeScalingPolicies_test").Return(nil).Once() - c.On("Put", "appAutoScalingDescribeScalingPolicies_test", results).Return(true).Once() - }, - want: []*applicationautoscaling.ScalingPolicy{ - { - PolicyARN: aws.String("test_policy"), - }, - }, - }, - { - name: "should hit cache return scaling policies", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - results := []*applicationautoscaling.ScalingPolicy{ - { - PolicyARN: aws.String("test_policy"), - }, - } - - c.On("Get", "appAutoScalingDescribeScalingPolicies_test").Return(results).Once() - }, - want: []*applicationautoscaling.ScalingPolicy{ - { - PolicyARN: aws.String("test_policy"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApplicationAutoScaling{} - tt.mocks(client, store) - - r := &appAutoScalingRepository{ - client: client, - cache: store, - } - got, err := r.DescribeScalingPolicies(tt.args.namespace) - if err != nil { - assert.EqualError(t, tt.wantErr, err.Error()) - } else { - assert.Equal(t, tt.wantErr, err) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - - client.AssertExpectations(t) - store.AssertExpectations(t) - }) - } -} - -func Test_appautoscalingRepository_DescribeScheduledActions(t *testing.T) { - type args struct { - namespace string - } - - tests := []struct { - name string - args args - mocks func(*awstest.MockFakeApplicationAutoScaling, *cache.MockCache) - want []*applicationautoscaling.ScheduledAction - wantErr error - }{ - { - name: "should return remote error", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - client.On("DescribeScheduledActions", - &applicationautoscaling.DescribeScheduledActionsInput{ - ServiceNamespace: aws.String("test"), - }).Return(nil, errors.New("remote error")).Once() - - c.On("Get", "appAutoScalingDescribeScheduledActions_test").Return(nil).Once() - }, - want: nil, - wantErr: errors.New("remote error"), - }, - { - name: "should return scheduled actions", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - results := []*applicationautoscaling.ScheduledAction{ - { - ResourceId: aws.String("test"), - }, - } - - client.On("DescribeScheduledActions", - &applicationautoscaling.DescribeScheduledActionsInput{ - ServiceNamespace: aws.String("test"), - }).Return(&applicationautoscaling.DescribeScheduledActionsOutput{ - ScheduledActions: results, - }, nil).Once() - - c.On("Get", "appAutoScalingDescribeScheduledActions_test").Return(nil).Once() - c.On("Put", "appAutoScalingDescribeScheduledActions_test", results).Return(true).Once() - }, - want: []*applicationautoscaling.ScheduledAction{ - { - ResourceId: aws.String("test"), - }, - }, - }, - { - name: "should hit cache return scheduled actions", - args: args{ - namespace: "test", - }, - mocks: func(client *awstest.MockFakeApplicationAutoScaling, c *cache.MockCache) { - results := []*applicationautoscaling.ScheduledAction{ - { - ResourceId: aws.String("test"), - }, - } - - c.On("Get", "appAutoScalingDescribeScheduledActions_test").Return(results).Once() - }, - want: []*applicationautoscaling.ScheduledAction{ - { - ResourceId: aws.String("test"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeApplicationAutoScaling{} - tt.mocks(client, store) - - r := &appAutoScalingRepository{ - client: client, - cache: store, - } - got, err := r.DescribeScheduledActions(tt.args.namespace) - if err != nil { - assert.EqualError(t, tt.wantErr, err.Error()) - } else { - assert.Equal(t, tt.wantErr, err) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - - client.AssertExpectations(t) - store.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/autoscaling_repository.go b/pkg/remote/aws/repository/autoscaling_repository.go deleted file mode 100644 index b28aa740..00000000 --- a/pkg/remote/aws/repository/autoscaling_repository.go +++ /dev/null @@ -1,44 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type AutoScalingRepository interface { - DescribeLaunchConfigurations() ([]*autoscaling.LaunchConfiguration, error) -} - -type autoScalingRepository struct { - client autoscalingiface.AutoScalingAPI - cache cache.Cache -} - -func NewAutoScalingRepository(session *session.Session, c cache.Cache) *autoScalingRepository { - return &autoScalingRepository{ - autoscaling.New(session), - c, - } -} - -func (r *autoScalingRepository) DescribeLaunchConfigurations() ([]*autoscaling.LaunchConfiguration, error) { - cacheKey := "DescribeLaunchConfigurations" - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*autoscaling.LaunchConfiguration), nil - } - - var results []*autoscaling.LaunchConfiguration - input := &autoscaling.DescribeLaunchConfigurationsInput{} - err := r.client.DescribeLaunchConfigurationsPages(input, func(resp *autoscaling.DescribeLaunchConfigurationsOutput, lastPage bool) bool { - results = append(results, resp.LaunchConfigurations...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, results) - return results, nil -} diff --git a/pkg/remote/aws/repository/autoscaling_repository_test.go b/pkg/remote/aws/repository/autoscaling_repository_test.go deleted file mode 100644 index 723f5954..00000000 --- a/pkg/remote/aws/repository/autoscaling_repository_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package repository - -import ( - "errors" - "strings" - "testing" - - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/aws/aws-sdk-go/aws" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_AutoscalingRepository_DescribeLaunchConfigurations(t *testing.T) { - dummryError := errors.New("dummy error") - - expectedLaunchConfigurations := []*autoscaling.LaunchConfiguration{ - {ImageId: aws.String("1")}, - {ImageId: aws.String("2")}, - {ImageId: aws.String("3")}, - {ImageId: aws.String("4")}, - } - - tests := []struct { - name string - mocks func(*awstest.MockFakeAutoscaling, *cache.MockCache) - want []*autoscaling.LaunchConfiguration - wantErr error - }{ - { - name: "List all launch configurations", - mocks: func(client *awstest.MockFakeAutoscaling, store *cache.MockCache) { - store.On("Get", "DescribeLaunchConfigurations").Return(nil).Once() - - client.On("DescribeLaunchConfigurationsPages", - &autoscaling.DescribeLaunchConfigurationsInput{}, - mock.MatchedBy(func(callback func(res *autoscaling.DescribeLaunchConfigurationsOutput, lastPage bool) bool) bool { - callback(&autoscaling.DescribeLaunchConfigurationsOutput{ - LaunchConfigurations: expectedLaunchConfigurations[:2], - }, false) - callback(&autoscaling.DescribeLaunchConfigurationsOutput{ - LaunchConfigurations: expectedLaunchConfigurations[2:], - }, true) - return true - })).Return(nil).Once() - - store.On("Put", "DescribeLaunchConfigurations", expectedLaunchConfigurations).Return(false).Once() - }, - want: expectedLaunchConfigurations, - }, - { - name: "Hit cache and list all launch configurations", - mocks: func(client *awstest.MockFakeAutoscaling, store *cache.MockCache) { - store.On("Get", "DescribeLaunchConfigurations").Return(expectedLaunchConfigurations).Once() - }, - want: expectedLaunchConfigurations, - }, - { - name: "Error listing all launch configurations", - mocks: func(client *awstest.MockFakeAutoscaling, store *cache.MockCache) { - store.On("Get", "DescribeLaunchConfigurations").Return(nil).Once() - - client.On("DescribeLaunchConfigurationsPages", &autoscaling.DescribeLaunchConfigurationsInput{}, mock.MatchedBy(func(callback func(res *autoscaling.DescribeLaunchConfigurationsOutput, lastPage bool) bool) bool { - callback(&autoscaling.DescribeLaunchConfigurationsOutput{ - LaunchConfigurations: []*autoscaling.LaunchConfiguration{}, - }, true) - return true - })).Return(dummryError).Once() - }, - want: nil, - wantErr: dummryError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeAutoscaling{} - tt.mocks(client, store) - r := &autoScalingRepository{ - client: client, - cache: store, - } - got, err := r.DescribeLaunchConfigurations() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/cloudformation_repository.go b/pkg/remote/aws/repository/cloudformation_repository.go deleted file mode 100644 index 52be2b63..00000000 --- a/pkg/remote/aws/repository/cloudformation_repository.go +++ /dev/null @@ -1,47 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/aws/aws-sdk-go/service/cloudformation/cloudformationiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type CloudformationRepository interface { - ListAllStacks() ([]*cloudformation.Stack, error) -} - -type cloudformationRepository struct { - client cloudformationiface.CloudFormationAPI - cache cache.Cache -} - -func NewCloudformationRepository(session *session.Session, c cache.Cache) *cloudformationRepository { - return &cloudformationRepository{ - cloudformation.New(session), - c, - } -} - -func (r *cloudformationRepository) ListAllStacks() ([]*cloudformation.Stack, error) { - if v := r.cache.Get("cloudformationListAllStacks"); v != nil { - return v.([]*cloudformation.Stack), nil - } - - var stacks []*cloudformation.Stack - input := cloudformation.DescribeStacksInput{} - err := r.client.DescribeStacksPages(&input, - func(resp *cloudformation.DescribeStacksOutput, lastPage bool) bool { - if resp.Stacks != nil { - stacks = append(stacks, resp.Stacks...) - } - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put("cloudformationListAllStacks", stacks) - return stacks, nil -} diff --git a/pkg/remote/aws/repository/cloudformation_repository_test.go b/pkg/remote/aws/repository/cloudformation_repository_test.go deleted file mode 100644 index 1dffdc28..00000000 --- a/pkg/remote/aws/repository/cloudformation_repository_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_cloudformationRepository_ListAllStacks(t *testing.T) { - stacks := []*cloudformation.Stack{ - {StackId: aws.String("stack1")}, - {StackId: aws.String("stack2")}, - {StackId: aws.String("stack3")}, - {StackId: aws.String("stack4")}, - {StackId: aws.String("stack5")}, - {StackId: aws.String("stack6")}, - } - - tests := []struct { - name string - mocks func(client *awstest.MockFakeCloudformation, store *cache.MockCache) - want []*cloudformation.Stack - wantErr error - }{ - { - name: "list multiple stacks", - mocks: func(client *awstest.MockFakeCloudformation, store *cache.MockCache) { - client.On("DescribeStacksPages", - &cloudformation.DescribeStacksInput{}, - mock.MatchedBy(func(callback func(res *cloudformation.DescribeStacksOutput, lastPage bool) bool) bool { - callback(&cloudformation.DescribeStacksOutput{ - Stacks: stacks[:3], - }, false) - callback(&cloudformation.DescribeStacksOutput{ - Stacks: stacks[3:], - }, true) - return true - })).Return(nil).Once() - - store.On("Get", "cloudformationListAllStacks").Return(nil).Times(1) - store.On("Put", "cloudformationListAllStacks", stacks).Return(false).Times(1) - }, - want: stacks, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeCloudformation, store *cache.MockCache) { - store.On("Get", "cloudformationListAllStacks").Return(stacks).Times(1) - }, - want: stacks, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeCloudformation{} - tt.mocks(client, store) - r := &cloudformationRepository{ - client: client, - cache: store, - } - got, err := r.ListAllStacks() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/cloudfront_repository.go b/pkg/remote/aws/repository/cloudfront_repository.go deleted file mode 100644 index 847b69a8..00000000 --- a/pkg/remote/aws/repository/cloudfront_repository.go +++ /dev/null @@ -1,47 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudfront" - "github.com/aws/aws-sdk-go/service/cloudfront/cloudfrontiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type CloudfrontRepository interface { - ListAllDistributions() ([]*cloudfront.DistributionSummary, error) -} - -type cloudfrontRepository struct { - client cloudfrontiface.CloudFrontAPI - cache cache.Cache -} - -func NewCloudfrontRepository(session *session.Session, c cache.Cache) *cloudfrontRepository { - return &cloudfrontRepository{ - cloudfront.New(session), - c, - } -} - -func (r *cloudfrontRepository) ListAllDistributions() ([]*cloudfront.DistributionSummary, error) { - if v := r.cache.Get("cloudfrontListAllDistributions"); v != nil { - return v.([]*cloudfront.DistributionSummary), nil - } - - var distributions []*cloudfront.DistributionSummary - input := cloudfront.ListDistributionsInput{} - err := r.client.ListDistributionsPages(&input, - func(resp *cloudfront.ListDistributionsOutput, lastPage bool) bool { - if resp.DistributionList != nil { - distributions = append(distributions, resp.DistributionList.Items...) - } - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put("cloudfrontListAllDistributions", distributions) - return distributions, nil -} diff --git a/pkg/remote/aws/repository/cloudfront_repository_test.go b/pkg/remote/aws/repository/cloudfront_repository_test.go deleted file mode 100644 index baa869be..00000000 --- a/pkg/remote/aws/repository/cloudfront_repository_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/cloudfront" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_cloudfrontRepository_ListAllDistributions(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeCloudFront) - want []*cloudfront.DistributionSummary - wantErr error - }{ - { - name: "list multiple distributions", - mocks: func(client *awstest.MockFakeCloudFront) { - client.On("ListDistributionsPages", - &cloudfront.ListDistributionsInput{}, - mock.MatchedBy(func(callback func(res *cloudfront.ListDistributionsOutput, lastPage bool) bool) bool { - callback(&cloudfront.ListDistributionsOutput{ - DistributionList: &cloudfront.DistributionList{ - Items: []*cloudfront.DistributionSummary{ - {Id: aws.String("distribution1")}, - {Id: aws.String("distribution2")}, - {Id: aws.String("distribution3")}, - }, - }, - }, false) - callback(&cloudfront.ListDistributionsOutput{ - DistributionList: &cloudfront.DistributionList{ - Items: []*cloudfront.DistributionSummary{ - {Id: aws.String("distribution4")}, - {Id: aws.String("distribution5")}, - {Id: aws.String("distribution6")}, - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*cloudfront.DistributionSummary{ - {Id: aws.String("distribution1")}, - {Id: aws.String("distribution2")}, - {Id: aws.String("distribution3")}, - {Id: aws.String("distribution4")}, - {Id: aws.String("distribution5")}, - {Id: aws.String("distribution6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeCloudFront{} - tt.mocks(&client) - r := &cloudfrontRepository{ - client: &client, - cache: store, - } - got, err := r.ListAllDistributions() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllDistributions() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*cloudfront.DistributionSummary{}, store.Get("cloudfrontListAllDistributions")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/dynamodb_repository.go b/pkg/remote/aws/repository/dynamodb_repository.go deleted file mode 100644 index e902f709..00000000 --- a/pkg/remote/aws/repository/dynamodb_repository.go +++ /dev/null @@ -1,43 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type DynamoDBRepository interface { - ListAllTables() ([]*string, error) -} - -type dynamoDBRepository struct { - client dynamodbiface.DynamoDBAPI - cache cache.Cache -} - -func NewDynamoDBRepository(session *session.Session, c cache.Cache) *dynamoDBRepository { - return &dynamoDBRepository{ - dynamodb.New(session), - c, - } -} - -func (r *dynamoDBRepository) ListAllTables() ([]*string, error) { - if v := r.cache.Get("dynamodbListAllTables"); v != nil { - return v.([]*string), nil - } - - var tables []*string - input := &dynamodb.ListTablesInput{} - err := r.client.ListTablesPages(input, func(res *dynamodb.ListTablesOutput, lastPage bool) bool { - tables = append(tables, res.TableNames...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("dynamodbListAllTables", tables) - return tables, nil -} diff --git a/pkg/remote/aws/repository/dynamodb_repository_test.go b/pkg/remote/aws/repository/dynamodb_repository_test.go deleted file mode 100644 index 00a03ab7..00000000 --- a/pkg/remote/aws/repository/dynamodb_repository_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_dynamoDBRepository_ListAllTopics(t *testing.T) { - - tests := []struct { - name string - mocks func(client *awstest.MockFakeDynamoDB) - want []*string - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeDynamoDB) { - client.On("ListTablesPages", - &dynamodb.ListTablesInput{}, - mock.MatchedBy(func(callback func(res *dynamodb.ListTablesOutput, lastPage bool) bool) bool { - callback(&dynamodb.ListTablesOutput{ - TableNames: []*string{ - aws.String("1"), - aws.String("2"), - aws.String("3"), - }, - }, false) - callback(&dynamodb.ListTablesOutput{ - TableNames: []*string{ - aws.String("4"), - aws.String("5"), - aws.String("6"), - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*string{ - aws.String("1"), - aws.String("2"), - aws.String("3"), - aws.String("4"), - aws.String("5"), - aws.String("6"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeDynamoDB{} - tt.mocks(&client) - r := &dynamoDBRepository{ - client: &client, - cache: store, - } - got, err := r.ListAllTables() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllTables() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*string{}, store.Get("dynamodbListAllTables")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/ec2_repository.go b/pkg/remote/aws/repository/ec2_repository.go deleted file mode 100644 index 776fd823..00000000 --- a/pkg/remote/aws/repository/ec2_repository.go +++ /dev/null @@ -1,408 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type EC2Repository interface { - ListAllImages() ([]*ec2.Image, error) - ListAllSnapshots() ([]*ec2.Snapshot, error) - ListAllVolumes() ([]*ec2.Volume, error) - ListAllAddresses() ([]*ec2.Address, error) - ListAllAddressesAssociation() ([]*ec2.Address, error) - ListAllInstances() ([]*ec2.Instance, error) - ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) - ListAllInternetGateways() ([]*ec2.InternetGateway, error) - ListAllSubnets() ([]*ec2.Subnet, []*ec2.Subnet, error) - ListAllNatGateways() ([]*ec2.NatGateway, error) - ListAllRouteTables() ([]*ec2.RouteTable, error) - ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) - ListAllSecurityGroups() ([]*ec2.SecurityGroup, []*ec2.SecurityGroup, error) - ListAllNetworkACLs() ([]*ec2.NetworkAcl, error) - DescribeLaunchTemplates() ([]*ec2.LaunchTemplate, error) - IsEbsEncryptionEnabledByDefault() (bool, error) -} - -type ec2Repository struct { - client ec2iface.EC2API - cache cache.Cache -} - -func NewEC2Repository(session *session.Session, c cache.Cache) *ec2Repository { - return &ec2Repository{ - ec2.New(session), - c, - } -} - -func (r *ec2Repository) ListAllImages() ([]*ec2.Image, error) { - if v := r.cache.Get("ec2ListAllImages"); v != nil { - return v.([]*ec2.Image), nil - } - - input := &ec2.DescribeImagesInput{ - Owners: []*string{ - aws.String("self"), - }, - } - images, err := r.client.DescribeImages(input) - if err != nil { - return nil, err - } - r.cache.Put("ec2ListAllImages", images.Images) - return images.Images, err -} - -func (r *ec2Repository) ListAllSnapshots() ([]*ec2.Snapshot, error) { - if v := r.cache.Get("ec2ListAllSnapshots"); v != nil { - return v.([]*ec2.Snapshot), nil - } - - var snapshots []*ec2.Snapshot - input := &ec2.DescribeSnapshotsInput{ - OwnerIds: []*string{ - aws.String("self"), - }, - } - err := r.client.DescribeSnapshotsPages(input, func(res *ec2.DescribeSnapshotsOutput, lastPage bool) bool { - snapshots = append(snapshots, res.Snapshots...) - return !lastPage - }) - if err != nil { - return nil, err - } - r.cache.Put("ec2ListAllSnapshots", snapshots) - return snapshots, err -} - -func (r *ec2Repository) ListAllVolumes() ([]*ec2.Volume, error) { - if v := r.cache.Get("ec2ListAllVolumes"); v != nil { - return v.([]*ec2.Volume), nil - } - - var volumes []*ec2.Volume - input := &ec2.DescribeVolumesInput{} - err := r.client.DescribeVolumesPages(input, func(res *ec2.DescribeVolumesOutput, lastPage bool) bool { - volumes = append(volumes, res.Volumes...) - return !lastPage - }) - if err != nil { - return nil, err - } - r.cache.Put("ec2ListAllVolumes", volumes) - return volumes, nil -} - -func (r *ec2Repository) ListAllAddresses() ([]*ec2.Address, error) { - cacheKey := "ec2ListAllAddresses" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*ec2.Address), nil - } - - input := &ec2.DescribeAddressesInput{} - response, err := r.client.DescribeAddresses(input) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, response.Addresses) - return response.Addresses, nil -} - -func (r *ec2Repository) ListAllAddressesAssociation() ([]*ec2.Address, error) { - if v := r.cache.Get("ec2ListAllAddressesAssociation"); v != nil { - return v.([]*ec2.Address), nil - } - - addresses, err := r.ListAllAddresses() - if err != nil { - return nil, err - } - results := make([]*ec2.Address, 0, len(addresses)) - - for _, address := range addresses { - if address.AssociationId != nil { - results = append(results, address) - } - } - r.cache.Put("ec2ListAllAddressesAssociation", results) - return results, nil -} - -func (r *ec2Repository) ListAllInstances() ([]*ec2.Instance, error) { - if v := r.cache.Get("ec2ListAllInstances"); v != nil { - return v.([]*ec2.Instance), nil - } - var instances []*ec2.Instance - input := &ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - { - // Ignore terminated state from enumeration since terminated means that instance - // has been removed - Name: aws.String("instance-state-name"), - Values: aws.StringSlice([]string{ - "pending", - "running", - "stopping", - "shutting-down", - "stopped", - }), - }, - }, - } - err := r.client.DescribeInstancesPages(input, func(res *ec2.DescribeInstancesOutput, lastPage bool) bool { - for _, reservation := range res.Reservations { - instances = append(instances, reservation.Instances...) - } - return !lastPage - }) - if err != nil { - return nil, err - } - r.cache.Put("ec2ListAllInstances", instances) - return instances, nil -} - -func (r *ec2Repository) ListAllKeyPairs() ([]*ec2.KeyPairInfo, error) { - if v := r.cache.Get("ec2ListAllKeyPairs"); v != nil { - return v.([]*ec2.KeyPairInfo), nil - } - - input := &ec2.DescribeKeyPairsInput{} - pairs, err := r.client.DescribeKeyPairs(input) - if err != nil { - return nil, err - } - r.cache.Put("ec2ListAllKeyPairs", pairs.KeyPairs) - return pairs.KeyPairs, err -} - -func (r *ec2Repository) ListAllInternetGateways() ([]*ec2.InternetGateway, error) { - if v := r.cache.Get("ec2ListAllInternetGateways"); v != nil { - return v.([]*ec2.InternetGateway), nil - } - - var internetGateways []*ec2.InternetGateway - input := ec2.DescribeInternetGatewaysInput{} - err := r.client.DescribeInternetGatewaysPages(&input, - func(resp *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool { - internetGateways = append(internetGateways, resp.InternetGateways...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - r.cache.Put("ec2ListAllInternetGateways", internetGateways) - return internetGateways, nil -} - -func (r *ec2Repository) ListAllSubnets() ([]*ec2.Subnet, []*ec2.Subnet, error) { - cacheKey := "ec2ListAllSubnets" - cacheSubnets := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - - defaultCacheKey := "ec2ListAllDefaultSubnets" - cacheDefaultSubnets := r.cache.GetAndLock(defaultCacheKey) - defer r.cache.Unlock(defaultCacheKey) - if cacheSubnets != nil && cacheDefaultSubnets != nil { - return cacheSubnets.([]*ec2.Subnet), cacheDefaultSubnets.([]*ec2.Subnet), nil - } - - input := ec2.DescribeSubnetsInput{} - var subnets []*ec2.Subnet - var defaultSubnets []*ec2.Subnet - err := r.client.DescribeSubnetsPages(&input, - func(resp *ec2.DescribeSubnetsOutput, lastPage bool) bool { - for _, subnet := range resp.Subnets { - if subnet.DefaultForAz != nil && *subnet.DefaultForAz { - defaultSubnets = append(defaultSubnets, subnet) - continue - } - subnets = append(subnets, subnet) - } - return !lastPage - }) - if err != nil { - return nil, nil, err - } - r.cache.Put(cacheKey, subnets) - r.cache.Put(defaultCacheKey, defaultSubnets) - return subnets, defaultSubnets, nil -} - -func (r *ec2Repository) ListAllNatGateways() ([]*ec2.NatGateway, error) { - if v := r.cache.Get("ec2ListAllNatGateways"); v != nil { - return v.([]*ec2.NatGateway), nil - } - - var result []*ec2.NatGateway - input := ec2.DescribeNatGatewaysInput{} - err := r.client.DescribeNatGatewaysPages(&input, - func(resp *ec2.DescribeNatGatewaysOutput, lastPage bool) bool { - result = append(result, resp.NatGateways...) - return !lastPage - }, - ) - - if err != nil { - return nil, err - } - - r.cache.Put("ec2ListAllNatGateways", result) - return result, nil -} - -func (r *ec2Repository) ListAllRouteTables() ([]*ec2.RouteTable, error) { - cacheKey := "ec2ListAllRouteTables" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*ec2.RouteTable), nil - } - - var routeTables []*ec2.RouteTable - input := ec2.DescribeRouteTablesInput{} - err := r.client.DescribeRouteTablesPages(&input, - func(resp *ec2.DescribeRouteTablesOutput, lastPage bool) bool { - routeTables = append(routeTables, resp.RouteTables...) - return !lastPage - }, - ) - - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, routeTables) - return routeTables, nil -} - -func (r *ec2Repository) ListAllVPCs() ([]*ec2.Vpc, []*ec2.Vpc, error) { - cacheKey := "ec2ListAllVPCs" - cacheVPCs := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - defaultCacheKey := "ec2ListAllDefaultVPCs" - cacheDefaultVPCs := r.cache.GetAndLock(defaultCacheKey) - defer r.cache.Unlock(defaultCacheKey) - if cacheVPCs != nil && cacheDefaultVPCs != nil { - return cacheVPCs.([]*ec2.Vpc), cacheDefaultVPCs.([]*ec2.Vpc), nil - } - - input := ec2.DescribeVpcsInput{} - var VPCs []*ec2.Vpc - var defaultVPCs []*ec2.Vpc - err := r.client.DescribeVpcsPages(&input, - func(resp *ec2.DescribeVpcsOutput, lastPage bool) bool { - for _, vpc := range resp.Vpcs { - if vpc.IsDefault != nil && *vpc.IsDefault { - defaultVPCs = append(defaultVPCs, vpc) - continue - } - VPCs = append(VPCs, vpc) - } - return !lastPage - }, - ) - if err != nil { - return nil, nil, err - } - - r.cache.Put(cacheKey, VPCs) - r.cache.Put(defaultCacheKey, defaultVPCs) - return VPCs, defaultVPCs, nil -} - -func (r *ec2Repository) ListAllSecurityGroups() ([]*ec2.SecurityGroup, []*ec2.SecurityGroup, error) { - cacheKey := "ec2ListAllSecurityGroups" - cacheSecurityGroups := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - defaultCacheKey := "ec2ListAllDefaultSecurityGroups" - cacheDefaultSecurityGroups := r.cache.GetAndLock(defaultCacheKey) - defer r.cache.Unlock(defaultCacheKey) - if cacheSecurityGroups != nil && cacheDefaultSecurityGroups != nil { - return cacheSecurityGroups.([]*ec2.SecurityGroup), cacheDefaultSecurityGroups.([]*ec2.SecurityGroup), nil - } - - var securityGroups []*ec2.SecurityGroup - var defaultSecurityGroups []*ec2.SecurityGroup - input := &ec2.DescribeSecurityGroupsInput{} - err := r.client.DescribeSecurityGroupsPages(input, func(res *ec2.DescribeSecurityGroupsOutput, lastPage bool) bool { - for _, securityGroup := range res.SecurityGroups { - if securityGroup.GroupName != nil && *securityGroup.GroupName == "default" { - defaultSecurityGroups = append(defaultSecurityGroups, securityGroup) - continue - } - securityGroups = append(securityGroups, securityGroup) - } - return !lastPage - }) - if err != nil { - return nil, nil, err - } - - r.cache.Put(cacheKey, securityGroups) - r.cache.Put(defaultCacheKey, defaultSecurityGroups) - return securityGroups, defaultSecurityGroups, nil -} - -func (r *ec2Repository) ListAllNetworkACLs() ([]*ec2.NetworkAcl, error) { - - cacheKey := "ec2ListAllNetworkACLs" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*ec2.NetworkAcl), nil - } - - var ACLs []*ec2.NetworkAcl - input := ec2.DescribeNetworkAclsInput{} - err := r.client.DescribeNetworkAclsPages(&input, - func(resp *ec2.DescribeNetworkAclsOutput, lastPage bool) bool { - ACLs = append(ACLs, resp.NetworkAcls...) - return !lastPage - }, - ) - - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, ACLs) - return ACLs, nil -} - -func (r *ec2Repository) DescribeLaunchTemplates() ([]*ec2.LaunchTemplate, error) { - cacheKey := "DescribeLaunchTemplates" - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*ec2.LaunchTemplate), nil - } - - input := ec2.DescribeLaunchTemplatesInput{} - resp, err := r.client.DescribeLaunchTemplates(&input) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resp.LaunchTemplates) - return resp.LaunchTemplates, nil -} - -func (r *ec2Repository) IsEbsEncryptionEnabledByDefault() (bool, error) { - if v := r.cache.Get("ec2IsEbsEncryptionEnabledByDefault"); v != nil { - return v.(bool), nil - } - - input := &ec2.GetEbsEncryptionByDefaultInput{} - resp, err := r.client.GetEbsEncryptionByDefault(input) - if err != nil { - return false, err - } - r.cache.Put("ec2IsEbsEncryptionEnabledByDefault", *resp.EbsEncryptionByDefault) - return *resp.EbsEncryptionByDefault, err -} diff --git a/pkg/remote/aws/repository/ec2_repository_test.go b/pkg/remote/aws/repository/ec2_repository_test.go deleted file mode 100644 index 2d0b4aeb..00000000 --- a/pkg/remote/aws/repository/ec2_repository_test.go +++ /dev/null @@ -1,1429 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/aws/aws-sdk-go/service/ec2" - - "github.com/aws/aws-sdk-go/aws" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_ec2Repository_ListAllImages(t *testing.T) { - - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.Image - wantErr error - }{ - { - name: "List all images", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeImages", - &ec2.DescribeImagesInput{ - Owners: []*string{ - aws.String("self"), - }, - }).Return(&ec2.DescribeImagesOutput{ - Images: []*ec2.Image{ - {ImageId: aws.String("1")}, - {ImageId: aws.String("2")}, - {ImageId: aws.String("3")}, - {ImageId: aws.String("4")}, - }, - }, nil).Once() - }, - want: []*ec2.Image{ - {ImageId: aws.String("1")}, - {ImageId: aws.String("2")}, - {ImageId: aws.String("3")}, - {ImageId: aws.String("4")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllImages() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllImages() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.Image{}, store.Get("ec2ListAllImages")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllSnapshots(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.Snapshot - wantErr error - }{ - {name: "List with 2 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeSnapshotsPages", - &ec2.DescribeSnapshotsInput{ - OwnerIds: []*string{ - aws.String("self"), - }, - }, - mock.MatchedBy(func(callback func(res *ec2.DescribeSnapshotsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeSnapshotsOutput{ - Snapshots: []*ec2.Snapshot{ - {VolumeId: aws.String("1")}, - {VolumeId: aws.String("2")}, - {VolumeId: aws.String("3")}, - {VolumeId: aws.String("4")}, - }, - }, false) - callback(&ec2.DescribeSnapshotsOutput{ - Snapshots: []*ec2.Snapshot{ - {VolumeId: aws.String("5")}, - {VolumeId: aws.String("6")}, - {VolumeId: aws.String("7")}, - {VolumeId: aws.String("8")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.Snapshot{ - {VolumeId: aws.String("1")}, - {VolumeId: aws.String("2")}, - {VolumeId: aws.String("3")}, - {VolumeId: aws.String("4")}, - {VolumeId: aws.String("5")}, - {VolumeId: aws.String("6")}, - {VolumeId: aws.String("7")}, - {VolumeId: aws.String("8")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllSnapshots() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllSnapshots() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.Snapshot{}, store.Get("ec2ListAllSnapshots")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllVolumes(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.Volume - wantErr error - }{ - {name: "List with 2 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeVolumesPages", - &ec2.DescribeVolumesInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeVolumesOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeVolumesOutput{ - Volumes: []*ec2.Volume{ - {VolumeId: aws.String("1")}, - {VolumeId: aws.String("2")}, - {VolumeId: aws.String("3")}, - {VolumeId: aws.String("4")}, - }, - }, false) - callback(&ec2.DescribeVolumesOutput{ - Volumes: []*ec2.Volume{ - {VolumeId: aws.String("5")}, - {VolumeId: aws.String("6")}, - {VolumeId: aws.String("7")}, - {VolumeId: aws.String("8")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.Volume{ - {VolumeId: aws.String("1")}, - {VolumeId: aws.String("2")}, - {VolumeId: aws.String("3")}, - {VolumeId: aws.String("4")}, - {VolumeId: aws.String("5")}, - {VolumeId: aws.String("6")}, - {VolumeId: aws.String("7")}, - {VolumeId: aws.String("8")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllVolumes() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllVolumes() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.Volume{}, store.Get("ec2ListAllVolumes")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllAddresses(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.Address - wantErr error - }{ - { - name: "List address", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}). - Return(&ec2.DescribeAddressesOutput{ - Addresses: []*ec2.Address{ - {AssociationId: aws.String("1")}, - {AssociationId: aws.String("2")}, - {AssociationId: aws.String("3")}, - {AssociationId: aws.String("4")}, - }, - }, nil).Once() - }, - want: []*ec2.Address{ - {AssociationId: aws.String("1")}, - {AssociationId: aws.String("2")}, - {AssociationId: aws.String("3")}, - {AssociationId: aws.String("4")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllAddresses() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllAddresses() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddresses")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllAddressesAssociation(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.Address - wantErr error - }{ - { - name: "List address", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeAddresses", &ec2.DescribeAddressesInput{}). - Return(&ec2.DescribeAddressesOutput{ - Addresses: []*ec2.Address{ - {AssociationId: aws.String("1")}, - {AssociationId: aws.String("2")}, - {AssociationId: aws.String("3")}, - {AssociationId: aws.String("4")}, - }, - }, nil).Once() - }, - want: []*ec2.Address{ - {AssociationId: aws.String("1")}, - {AssociationId: aws.String("2")}, - {AssociationId: aws.String("3")}, - {AssociationId: aws.String("4")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllAddressesAssociation() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllAddressesAssociation() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.Address{}, store.Get("ec2ListAllAddressesAssociation")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllInstances(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.Instance - wantErr error - }{ - {name: "List with 2 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeInstancesPages", - &ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - { - Name: aws.String("instance-state-name"), - Values: aws.StringSlice([]string{ - "pending", - "running", - "stopping", - "shutting-down", - "stopped", - }), - }, - }, - }, - mock.MatchedBy(func(callback func(res *ec2.DescribeInstancesOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{ - { - Instances: []*ec2.Instance{ - {ImageId: aws.String("1")}, - {ImageId: aws.String("2")}, - {ImageId: aws.String("3")}, - }, - }, - { - Instances: []*ec2.Instance{ - {ImageId: aws.String("4")}, - {ImageId: aws.String("5")}, - {ImageId: aws.String("6")}, - }, - }, - }, - }, false) - callback(&ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{ - { - Instances: []*ec2.Instance{ - {ImageId: aws.String("7")}, - {ImageId: aws.String("8")}, - {ImageId: aws.String("9")}, - }, - }, - { - Instances: []*ec2.Instance{ - {ImageId: aws.String("10")}, - {ImageId: aws.String("11")}, - {ImageId: aws.String("12")}, - }, - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.Instance{ - {ImageId: aws.String("1")}, - {ImageId: aws.String("2")}, - {ImageId: aws.String("3")}, - {ImageId: aws.String("4")}, - {ImageId: aws.String("5")}, - {ImageId: aws.String("6")}, - {ImageId: aws.String("7")}, - {ImageId: aws.String("8")}, - {ImageId: aws.String("9")}, - {ImageId: aws.String("10")}, - {ImageId: aws.String("11")}, - {ImageId: aws.String("12")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllInstances() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllInstances() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.Instance{}, store.Get("ec2ListAllInstances")) - } - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllKeyPairs(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.KeyPairInfo - wantErr error - }{ - { - name: "List address", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeKeyPairs", &ec2.DescribeKeyPairsInput{}). - Return(&ec2.DescribeKeyPairsOutput{ - KeyPairs: []*ec2.KeyPairInfo{ - {KeyPairId: aws.String("1")}, - {KeyPairId: aws.String("2")}, - {KeyPairId: aws.String("3")}, - {KeyPairId: aws.String("4")}, - }, - }, nil).Once() - }, - want: []*ec2.KeyPairInfo{ - {KeyPairId: aws.String("1")}, - {KeyPairId: aws.String("2")}, - {KeyPairId: aws.String("3")}, - {KeyPairId: aws.String("4")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllKeyPairs() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllKeyPairs() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.KeyPairInfo{}, store.Get("ec2ListAllKeyPairs")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllInternetGateways(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.InternetGateway - wantErr error - }{ - { - name: "List only gateways with multiple pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeInternetGatewaysPages", - &ec2.DescribeInternetGatewaysInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeInternetGatewaysOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ - { - InternetGatewayId: aws.String("Internet-0"), - }, - { - InternetGatewayId: aws.String("Internet-1"), - }, - }, - }, false) - callback(&ec2.DescribeInternetGatewaysOutput{ - InternetGateways: []*ec2.InternetGateway{ - { - InternetGatewayId: aws.String("Internet-2"), - }, - { - InternetGatewayId: aws.String("Internet-3"), - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.InternetGateway{ - { - InternetGatewayId: aws.String("Internet-0"), - }, - { - InternetGatewayId: aws.String("Internet-1"), - }, - { - InternetGatewayId: aws.String("Internet-2"), - }, - { - InternetGatewayId: aws.String("Internet-3"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllInternetGateways() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllInternetGateways() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.InternetGateway{}, store.Get("ec2ListAllInternetGateways")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllSubnets(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - wantSubnet []*ec2.Subnet - wantDefaultSubnet []*ec2.Subnet - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeSubnetsPages", - &ec2.DescribeSubnetsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeSubnetsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeSubnetsOutput{ - Subnets: []*ec2.Subnet{ - { - SubnetId: aws.String("subnet-0b13f1e0eacf67424"), // subnet2 - DefaultForAz: aws.Bool(false), - }, - { - SubnetId: aws.String("subnet-0c9b78001fe186e22"), // subnet3 - DefaultForAz: aws.Bool(false), - }, - { - SubnetId: aws.String("subnet-05810d3f933925f6d"), // subnet1 - DefaultForAz: aws.Bool(false), - }, - }, - }, false) - callback(&ec2.DescribeSubnetsOutput{ - Subnets: []*ec2.Subnet{ - { - SubnetId: aws.String("subnet-44fe0c65"), // us-east-1a - DefaultForAz: aws.Bool(true), - }, - { - SubnetId: aws.String("subnet-65e16628"), // us-east-1b - DefaultForAz: aws.Bool(true), - }, - { - SubnetId: aws.String("subnet-afa656f0"), // us-east-1c - DefaultForAz: aws.Bool(true), - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - wantSubnet: []*ec2.Subnet{ - { - SubnetId: aws.String("subnet-0b13f1e0eacf67424"), // subnet2 - DefaultForAz: aws.Bool(false), - }, - { - SubnetId: aws.String("subnet-0c9b78001fe186e22"), // subnet3 - DefaultForAz: aws.Bool(false), - }, - { - SubnetId: aws.String("subnet-05810d3f933925f6d"), // subnet1 - DefaultForAz: aws.Bool(false), - }, - }, - wantDefaultSubnet: []*ec2.Subnet{ - { - SubnetId: aws.String("subnet-44fe0c65"), // us-east-1a - DefaultForAz: aws.Bool(true), - }, - { - SubnetId: aws.String("subnet-65e16628"), // us-east-1b - DefaultForAz: aws.Bool(true), - }, - { - SubnetId: aws.String("subnet-afa656f0"), // us-east-1c - DefaultForAz: aws.Bool(true), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - gotSubnet, gotDefaultSubnet, err := r.ListAllSubnets() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, cachedDefaultData, err := r.ListAllSubnets() - assert.NoError(t, err) - assert.Equal(t, gotSubnet, cachedData) - assert.Equal(t, gotDefaultSubnet, cachedDefaultData) - assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllSubnets")) - assert.IsType(t, []*ec2.Subnet{}, store.Get("ec2ListAllDefaultSubnets")) - } - - changelog, err := diff.Diff(gotSubnet, tt.wantSubnet) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - changelog, err = diff.Diff(gotDefaultSubnet, tt.wantDefaultSubnet) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllNatGateways(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.NatGateway - wantErr error - }{ - { - name: "List only gateways with multiple pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeNatGatewaysPages", - &ec2.DescribeNatGatewaysInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeNatGatewaysOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeNatGatewaysOutput{ - NatGateways: []*ec2.NatGateway{ - { - NatGatewayId: aws.String("nat-0"), - }, - { - NatGatewayId: aws.String("nat-1"), - }, - }, - }, false) - callback(&ec2.DescribeNatGatewaysOutput{ - NatGateways: []*ec2.NatGateway{ - { - NatGatewayId: aws.String("nat-2"), - }, - { - NatGatewayId: aws.String("nat-3"), - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.NatGateway{ - { - NatGatewayId: aws.String("nat-0"), - }, - { - NatGatewayId: aws.String("nat-1"), - }, - { - NatGatewayId: aws.String("nat-2"), - }, - { - NatGatewayId: aws.String("nat-3"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllNatGateways() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllNatGateways() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.NatGateway{}, store.Get("ec2ListAllNatGateways")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllRouteTables(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.RouteTable - wantErr error - }{ - { - name: "List only route with multiple pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeRouteTablesPages", - &ec2.DescribeRouteTablesInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeRouteTablesOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ - { - RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: aws.String("1.1.1.1/32"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - { - DestinationIpv6CidrBlock: aws.String("::/0"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - }, - }, - { - RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: aws.String("0.0.0.0/0"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - { - DestinationIpv6CidrBlock: aws.String("::/0"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - }, - }, - }, - }, false) - callback(&ec2.DescribeRouteTablesOutput{ - RouteTables: []*ec2.RouteTable{ - { - RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table - VpcId: aws.String("vpc-09fe5abc2309ba49d"), - Associations: []*ec2.RouteTableAssociation{ - { - Main: aws.Bool(true), - }, - }, - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: aws.String("10.1.1.0/24"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - { - DestinationCidrBlock: aws.String("10.1.2.0/24"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - }, - }, - { - RouteTableId: aws.String(""), // table3 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - }, - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.RouteTable{ - { - RouteTableId: aws.String("rtb-096bdfb69309c54c3"), // table1 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: aws.String("1.1.1.1/32"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - { - DestinationIpv6CidrBlock: aws.String("::/0"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - }, - }, - { - RouteTableId: aws.String("rtb-0169b0937fd963ddc"), // table2 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: aws.String("0.0.0.0/0"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - { - DestinationIpv6CidrBlock: aws.String("::/0"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - }, - }, - { - RouteTableId: aws.String("rtb-02780c485f0be93c5"), // default_table - VpcId: aws.String("vpc-09fe5abc2309ba49d"), - Associations: []*ec2.RouteTableAssociation{ - { - Main: aws.Bool(true), - }, - }, - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: aws.String("10.1.1.0/24"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - { - DestinationCidrBlock: aws.String("10.1.2.0/24"), - GatewayId: aws.String("igw-030e74f73bd67f21b"), - }, - }, - }, - { - RouteTableId: aws.String(""), // table3 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: aws.String("10.0.0.0/16"), - Origin: aws.String("CreateRouteTable"), // default route - }, - }, - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllRouteTables() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllRouteTables() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.RouteTable{}, store.Get("ec2ListAllRouteTables")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllVPCs(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - wantVPC []*ec2.Vpc - wantDefaultVPC []*ec2.Vpc - wantErr error - }{ - { - name: "mixed default VPC and VPC", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeVpcsPages", - &ec2.DescribeVpcsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeVpcsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ - { - VpcId: aws.String("vpc-a8c5d4c1"), - IsDefault: aws.Bool(true), - }, - { - VpcId: aws.String("vpc-0768e1fd0029e3fc3"), - }, - { - VpcId: aws.String("vpc-020b072316a95b97f"), - IsDefault: aws.Bool(false), - }, - }, - }, false) - callback(&ec2.DescribeVpcsOutput{ - Vpcs: []*ec2.Vpc{ - { - VpcId: aws.String("vpc-02c50896b59598761"), - IsDefault: aws.Bool(false), - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - wantVPC: []*ec2.Vpc{ - { - VpcId: aws.String("vpc-0768e1fd0029e3fc3"), - }, - { - VpcId: aws.String("vpc-020b072316a95b97f"), - IsDefault: aws.Bool(false), - }, - { - VpcId: aws.String("vpc-02c50896b59598761"), - IsDefault: aws.Bool(false), - }, - }, - wantDefaultVPC: []*ec2.Vpc{ - { - VpcId: aws.String("vpc-a8c5d4c1"), - IsDefault: aws.Bool(true), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - gotVPCs, gotDefaultVPCs, err := r.ListAllVPCs() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, cachedDefaultData, err := r.ListAllVPCs() - assert.NoError(t, err) - assert.Equal(t, gotVPCs, cachedData) - assert.Equal(t, gotDefaultVPCs, cachedDefaultData) - assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllVPCs")) - assert.IsType(t, []*ec2.Vpc{}, store.Get("ec2ListAllDefaultVPCs")) - } - - changelog, err := diff.Diff(gotVPCs, tt.wantVPC) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - changelog, err = diff.Diff(gotDefaultVPCs, tt.wantDefaultVPC) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllSecurityGroups(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - wantSecurityGroup []*ec2.SecurityGroup - wantDefaultSecurityGroup []*ec2.SecurityGroup - wantErr error - }{ - { - name: "List with 1 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeSecurityGroupsPages", - &ec2.DescribeSecurityGroupsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeSecurityGroupsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeSecurityGroupsOutput{ - SecurityGroups: []*ec2.SecurityGroup{ - { - GroupId: aws.String("sg-0254c038e32f25530"), - GroupName: aws.String("foo"), - }, - { - GroupId: aws.String("sg-9e0204ff"), - GroupName: aws.String("default"), - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - wantSecurityGroup: []*ec2.SecurityGroup{ - { - GroupId: aws.String("sg-0254c038e32f25530"), - GroupName: aws.String("foo"), - }, - }, - wantDefaultSecurityGroup: []*ec2.SecurityGroup{ - { - GroupId: aws.String("sg-9e0204ff"), - GroupName: aws.String("default"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - gotSecurityGroups, gotDefaultSecurityGroups, err := r.ListAllSecurityGroups() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, cachedDefaultData, err := r.ListAllSecurityGroups() - assert.NoError(t, err) - assert.Equal(t, gotSecurityGroups, cachedData) - assert.Equal(t, gotDefaultSecurityGroups, cachedDefaultData) - assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllSecurityGroups")) - assert.IsType(t, []*ec2.SecurityGroup{}, store.Get("ec2ListAllDefaultSecurityGroups")) - } - - changelog, err := diff.Diff(gotSecurityGroups, tt.wantSecurityGroup) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - changelog, err = diff.Diff(gotDefaultSecurityGroups, tt.wantDefaultSecurityGroup) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ec2Repository_ListAllNetworkACLs(t *testing.T) { - - testErr := errors.New("test") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.NetworkAcl - wantErr error - }{ - { - name: "List with 1 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeNetworkAclsPages", - &ec2.DescribeNetworkAclsInput{}, - mock.MatchedBy(func(callback func(res *ec2.DescribeNetworkAclsOutput, lastPage bool) bool) bool { - callback(&ec2.DescribeNetworkAclsOutput{ - NetworkAcls: []*ec2.NetworkAcl{ - { - NetworkAclId: aws.String("id1"), - }, - { - NetworkAclId: aws.String("id2"), - }, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ec2.NetworkAcl{ - { - NetworkAclId: aws.String("id1"), - }, - { - NetworkAclId: aws.String("id2"), - }, - }, - }, - { - name: "List return error", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeNetworkAclsPages", - &ec2.DescribeNetworkAclsInput{}, - mock.Anything, - ).Return(testErr) - }, - wantErr: testErr, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllNetworkACLs() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllNetworkACLs() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.NetworkAcl{}, store.Get("ec2ListAllNetworkACLs")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - client.AssertExpectations(t) - }) - } -} - -func Test_ec2Repository_DescribeLaunchTemplates(t *testing.T) { - - testErr := errors.New("test") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2) - want []*ec2.LaunchTemplate - wantErr error - }{ - { - name: "List with 1 pages", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeLaunchTemplates", - &ec2.DescribeLaunchTemplatesInput{}, - ).Return(&ec2.DescribeLaunchTemplatesOutput{ - LaunchTemplates: []*ec2.LaunchTemplate{ - { - LaunchTemplateId: aws.String("id1"), - }, - { - LaunchTemplateId: aws.String("id2"), - }, - }, - }, nil).Once() - }, - want: []*ec2.LaunchTemplate{ - { - LaunchTemplateId: aws.String("id1"), - }, - { - LaunchTemplateId: aws.String("id2"), - }, - }, - }, - { - name: "List return error", - mocks: func(client *awstest.MockFakeEC2) { - client.On("DescribeLaunchTemplates", - &ec2.DescribeLaunchTemplatesInput{}, - ).Return(nil, testErr) - }, - wantErr: testErr, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeEC2{} - tt.mocks(client) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.DescribeLaunchTemplates() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.DescribeLaunchTemplates() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ec2.LaunchTemplate{}, store.Get("DescribeLaunchTemplates")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - client.AssertExpectations(t) - }) - } -} - -func Test_ec2Repository_IsEbsEncryptionEnabledByDefault(t *testing.T) { - - testErr := errors.New("test") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeEC2, store *cache.MockCache) - want bool - wantErr error - }{ - { - name: "test that encryption enabled by default", - mocks: func(client *awstest.MockFakeEC2, store *cache.MockCache) { - store.On("Get", "ec2IsEbsEncryptionEnabledByDefault"). - Return(nil). - Once() - - client.On("GetEbsEncryptionByDefault", - &ec2.GetEbsEncryptionByDefaultInput{}, - ).Return(&ec2.GetEbsEncryptionByDefaultOutput{ - EbsEncryptionByDefault: aws.Bool(true), - }, nil).Once() - - store.On("Put", "ec2IsEbsEncryptionEnabledByDefault", true). - Return(false). - Once() - }, - want: true, - }, - { - name: "test that encryption enabled by default (cached)", - mocks: func(client *awstest.MockFakeEC2, store *cache.MockCache) { - store.On("Get", "ec2IsEbsEncryptionEnabledByDefault"). - Return(false). - Once() - }, - want: false, - }, - { - name: "error while getting default encryption value", - mocks: func(client *awstest.MockFakeEC2, store *cache.MockCache) { - store.On("Get", "ec2IsEbsEncryptionEnabledByDefault"). - Return(nil). - Once() - - client.On("GetEbsEncryptionByDefault", - &ec2.GetEbsEncryptionByDefaultInput{}, - ).Return(nil, testErr).Once() - }, - wantErr: testErr, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeEC2{} - tt.mocks(client, store) - r := &ec2Repository{ - client: client, - cache: store, - } - got, err := r.IsEbsEncryptionEnabledByDefault() - - assert.Equal(t, tt.wantErr, err) - assert.Equal(t, tt.want, got) - - client.AssertExpectations(t) - store.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/ecr_repository.go b/pkg/remote/aws/repository/ecr_repository.go deleted file mode 100644 index 9ca2f50f..00000000 --- a/pkg/remote/aws/repository/ecr_repository.go +++ /dev/null @@ -1,66 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ecr" - "github.com/aws/aws-sdk-go/service/ecr/ecriface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ECRRepository interface { - ListAllRepositories() ([]*ecr.Repository, error) - GetRepositoryPolicy(*ecr.Repository) (*ecr.GetRepositoryPolicyOutput, error) -} - -type ecrRepository struct { - client ecriface.ECRAPI - cache cache.Cache -} - -func NewECRRepository(session *session.Session, c cache.Cache) *ecrRepository { - return &ecrRepository{ - ecr.New(session), - c, - } -} - -func (r *ecrRepository) ListAllRepositories() ([]*ecr.Repository, error) { - if v := r.cache.Get("ecrListAllRepositories"); v != nil { - return v.([]*ecr.Repository), nil - } - - var repositories []*ecr.Repository - input := &ecr.DescribeRepositoriesInput{} - err := r.client.DescribeRepositoriesPages(input, func(res *ecr.DescribeRepositoriesOutput, lastPage bool) bool { - repositories = append(repositories, res.Repositories...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("ecrListAllRepositories", repositories) - return repositories, nil -} - -func (r *ecrRepository) GetRepositoryPolicy(repo *ecr.Repository) (*ecr.GetRepositoryPolicyOutput, error) { - cacheKey := fmt.Sprintf("ecrListAllRepositoriesGetRepositoryPolicy_%s_%s", *repo.RegistryId, *repo.RepositoryName) - if v := r.cache.Get(cacheKey); v != nil { - return v.(*ecr.GetRepositoryPolicyOutput), nil - } - - var repositoryPolicyInput *ecr.GetRepositoryPolicyInput = &ecr.GetRepositoryPolicyInput{ - RegistryId: repo.RegistryId, - RepositoryName: repo.RepositoryName, - } - - repoOutput, err := r.client.GetRepositoryPolicy(repositoryPolicyInput) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, repoOutput) - return repoOutput, nil -} diff --git a/pkg/remote/aws/repository/ecr_repository_test.go b/pkg/remote/aws/repository/ecr_repository_test.go deleted file mode 100644 index 2c8f8190..00000000 --- a/pkg/remote/aws/repository/ecr_repository_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - "testing" - - "github.com/aws/aws-sdk-go/service/ecr" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/aws/aws-sdk-go/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_ecrRepository_ListAllRepositories(t *testing.T) { - - tests := []struct { - name string - mocks func(client *awstest.MockFakeECR) - want []*ecr.Repository - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeECR) { - client.On("DescribeRepositoriesPages", - &ecr.DescribeRepositoriesInput{}, - mock.MatchedBy(func(callback func(res *ecr.DescribeRepositoriesOutput, lastPage bool) bool) bool { - callback(&ecr.DescribeRepositoriesOutput{ - Repositories: []*ecr.Repository{ - {RepositoryName: aws.String("1")}, - {RepositoryName: aws.String("2")}, - {RepositoryName: aws.String("3")}, - }, - }, false) - callback(&ecr.DescribeRepositoriesOutput{ - Repositories: []*ecr.Repository{ - {RepositoryName: aws.String("4")}, - {RepositoryName: aws.String("5")}, - {RepositoryName: aws.String("6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*ecr.Repository{ - {RepositoryName: aws.String("1")}, - {RepositoryName: aws.String("2")}, - {RepositoryName: aws.String("3")}, - {RepositoryName: aws.String("4")}, - {RepositoryName: aws.String("5")}, - {RepositoryName: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeECR{} - tt.mocks(&client) - r := &ecrRepository{ - client: &client, - cache: store, - } - got, err := r.ListAllRepositories() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllRepositories() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*ecr.Repository{}, store.Get("ecrListAllRepositories")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ecrRepository_GetRepositoryPolicy(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeECR) - want *ecr.GetRepositoryPolicyOutput - wantErr error - }{ - { - name: "Get repository policy", - mocks: func(client *awstest.MockFakeECR) { - client.On("GetRepositoryPolicy", - &ecr.GetRepositoryPolicyInput{ - RegistryId: aws.String("1"), - RepositoryName: aws.String("2"), - }, - ).Return(&ecr.GetRepositoryPolicyOutput{ - RegistryId: aws.String("1"), - RepositoryName: aws.String("2"), - }, nil).Once() - }, - want: &ecr.GetRepositoryPolicyOutput{ - RegistryId: aws.String("1"), - RepositoryName: aws.String("2"), - }, - }, - { - name: "Get repository policy error", - mocks: func(client *awstest.MockFakeECR) { - client.On("GetRepositoryPolicy", - &ecr.GetRepositoryPolicyInput{ - RegistryId: aws.String("1"), - RepositoryName: aws.String("2"), - }, - ).Return(nil, dummyError).Once() - }, - wantErr: dummyError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeECR{} - tt.mocks(&client) - r := &ecrRepository{ - client: &client, - cache: store, - } - - repo := &ecr.Repository{ - RegistryId: aws.String("1"), - RepositoryName: aws.String("2"), - } - - got, err := r.GetRepositoryPolicy(repo) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.GetRepositoryPolicy(repo) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - cacheKey := fmt.Sprintf("ecrListAllRepositoriesGetRepositoryPolicy_%s_%s", *repo.RegistryId, *repo.RepositoryName) - assert.IsType(t, &ecr.GetRepositoryPolicyOutput{}, store.Get(cacheKey)) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/elasticache_repository.go b/pkg/remote/aws/repository/elasticache_repository.go deleted file mode 100644 index 8d05cf0a..00000000 --- a/pkg/remote/aws/repository/elasticache_repository.go +++ /dev/null @@ -1,45 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/elasticache" - "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ElastiCacheRepository interface { - ListAllCacheClusters() ([]*elasticache.CacheCluster, error) -} - -type elasticacheRepository struct { - client elasticacheiface.ElastiCacheAPI - cache cache.Cache -} - -func NewElastiCacheRepository(session *session.Session, c cache.Cache) *elasticacheRepository { - return &elasticacheRepository{ - elasticache.New(session), - c, - } -} - -func (r *elasticacheRepository) ListAllCacheClusters() ([]*elasticache.CacheCluster, error) { - if v := r.cache.Get("elasticacheListAllCacheClusters"); v != nil { - return v.([]*elasticache.CacheCluster), nil - } - - var clusters []*elasticache.CacheCluster - input := elasticache.DescribeCacheClustersInput{} - err := r.client.DescribeCacheClustersPages(&input, - func(resp *elasticache.DescribeCacheClustersOutput, lastPage bool) bool { - clusters = append(clusters, resp.CacheClusters...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put("elasticacheListAllCacheClusters", clusters) - return clusters, nil -} diff --git a/pkg/remote/aws/repository/elasticache_repository_test.go b/pkg/remote/aws/repository/elasticache_repository_test.go deleted file mode 100644 index fa01590b..00000000 --- a/pkg/remote/aws/repository/elasticache_repository_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/elasticache" - "github.com/pkg/errors" - "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_elasticacheRepository_ListAllCacheClusters(t *testing.T) { - clusters := []*elasticache.CacheCluster{ - {CacheClusterId: aws.String("cluster1")}, - {CacheClusterId: aws.String("cluster2")}, - {CacheClusterId: aws.String("cluster3")}, - {CacheClusterId: aws.String("cluster4")}, - {CacheClusterId: aws.String("cluster5")}, - {CacheClusterId: aws.String("cluster6")}, - } - - remoteError := errors.New("remote error") - - tests := []struct { - name string - mocks func(client *awstest.MockFakeElastiCache, store *cache.MockCache) - want []*elasticache.CacheCluster - wantErr error - }{ - { - name: "List cache clusters", - mocks: func(client *awstest.MockFakeElastiCache, store *cache.MockCache) { - client.On("DescribeCacheClustersPages", - &elasticache.DescribeCacheClustersInput{}, - mock.MatchedBy(func(callback func(res *elasticache.DescribeCacheClustersOutput, lastPage bool) bool) bool { - callback(&elasticache.DescribeCacheClustersOutput{ - CacheClusters: clusters[:3], - }, false) - callback(&elasticache.DescribeCacheClustersOutput{ - CacheClusters: clusters[3:], - }, true) - return true - })).Return(nil).Once() - store.On("Get", "elasticacheListAllCacheClusters").Return(nil).Times(1) - store.On("Put", "elasticacheListAllCacheClusters", clusters).Return(false).Times(1) - }, - want: clusters, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeElastiCache, store *cache.MockCache) { - store.On("Get", "elasticacheListAllCacheClusters").Return(clusters).Times(1) - }, - want: clusters, - }, - { - name: "should return remote error", - mocks: func(client *awstest.MockFakeElastiCache, store *cache.MockCache) { - client.On("DescribeCacheClustersPages", - &elasticache.DescribeCacheClustersInput{}, - mock.AnythingOfType("func(*elasticache.DescribeCacheClustersOutput, bool) bool")).Return(remoteError).Once() - store.On("Get", "elasticacheListAllCacheClusters").Return(nil).Times(1) - }, - wantErr: remoteError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeElastiCache{} - tt.mocks(client, store) - r := &elasticacheRepository{ - client: client, - cache: store, - } - got, err := r.ListAllCacheClusters() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - store.AssertExpectations(t) - client.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/elb_repository.go b/pkg/remote/aws/repository/elb_repository.go deleted file mode 100644 index 3fe18147..00000000 --- a/pkg/remote/aws/repository/elb_repository.go +++ /dev/null @@ -1,43 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/aws/aws-sdk-go/service/elb/elbiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ELBRepository interface { - ListAllLoadBalancers() ([]*elb.LoadBalancerDescription, error) -} - -type elbRepository struct { - client elbiface.ELBAPI - cache cache.Cache -} - -func NewELBRepository(session *session.Session, c cache.Cache) *elbRepository { - return &elbRepository{ - elb.New(session), - c, - } -} - -func (r *elbRepository) ListAllLoadBalancers() ([]*elb.LoadBalancerDescription, error) { - if v := r.cache.Get("elbListAllLoadBalancers"); v != nil { - return v.([]*elb.LoadBalancerDescription), nil - } - - results := make([]*elb.LoadBalancerDescription, 0) - input := elb.DescribeLoadBalancersInput{} - err := r.client.DescribeLoadBalancersPages(&input, func(res *elb.DescribeLoadBalancersOutput, lastPage bool) bool { - results = append(results, res.LoadBalancerDescriptions...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("elbListAllLoadBalancers", results) - return results, nil -} diff --git a/pkg/remote/aws/repository/elb_repository_test.go b/pkg/remote/aws/repository/elb_repository_test.go deleted file mode 100644 index 48be952f..00000000 --- a/pkg/remote/aws/repository/elb_repository_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package repository - -import ( - "errors" - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_ELBRepository_ListAllLoadBalancers(t *testing.T) { - dummyErr := errors.New("dummy error") - - results := []*elb.LoadBalancerDescription{ - { - LoadBalancerName: aws.String("test-lb-1"), - }, - { - LoadBalancerName: aws.String("test-lb-2"), - }, - } - - tests := []struct { - name string - mocks func(*awstest.MockFakeELB, *cache.MockCache) - want []*elb.LoadBalancerDescription - wantErr error - }{ - { - name: "List load balancers with multiple pages", - mocks: func(client *awstest.MockFakeELB, store *cache.MockCache) { - store.On("Get", "elbListAllLoadBalancers").Return(nil).Once() - - client.On("DescribeLoadBalancersPages", - &elb.DescribeLoadBalancersInput{}, - mock.MatchedBy(func(callback func(res *elb.DescribeLoadBalancersOutput, lastPage bool) bool) bool { - callback(&elb.DescribeLoadBalancersOutput{LoadBalancerDescriptions: []*elb.LoadBalancerDescription{ - results[0], - }}, false) - callback(&elb.DescribeLoadBalancersOutput{LoadBalancerDescriptions: []*elb.LoadBalancerDescription{ - results[1], - }}, true) - return true - })).Return(nil).Once() - - store.On("Put", "elbListAllLoadBalancers", results).Return(false).Once() - }, - want: []*elb.LoadBalancerDescription{ - { - LoadBalancerName: aws.String("test-lb-1"), - }, - { - LoadBalancerName: aws.String("test-lb-2"), - }, - }, - }, - { - name: "List load balancers with multiple pages (cache hit)", - mocks: func(client *awstest.MockFakeELB, store *cache.MockCache) { - store.On("Get", "elbListAllLoadBalancers").Return(results).Once() - }, - want: []*elb.LoadBalancerDescription{ - { - LoadBalancerName: aws.String("test-lb-1"), - }, - { - LoadBalancerName: aws.String("test-lb-2"), - }, - }, - }, - { - name: "Error listing load balancers", - mocks: func(client *awstest.MockFakeELB, store *cache.MockCache) { - store.On("Get", "elbListAllLoadBalancers").Return(nil).Once() - - client.On("DescribeLoadBalancersPages", - &elb.DescribeLoadBalancersInput{}, - mock.MatchedBy(func(callback func(res *elb.DescribeLoadBalancersOutput, lastPage bool) bool) bool { - callback(&elb.DescribeLoadBalancersOutput{LoadBalancerDescriptions: []*elb.LoadBalancerDescription{}}, true) - return true - })).Return(dummyErr).Once() - }, - wantErr: dummyErr, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeELB{} - tt.mocks(client, store) - r := &elbRepository{ - client: client, - cache: store, - } - got, err := r.ListAllLoadBalancers() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - - client.AssertExpectations(t) - store.AssertExpectations(t) - }) - } -} diff --git a/pkg/remote/aws/repository/elbv2_repository.go b/pkg/remote/aws/repository/elbv2_repository.go deleted file mode 100644 index aab45686..00000000 --- a/pkg/remote/aws/repository/elbv2_repository.go +++ /dev/null @@ -1,65 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/aws/aws-sdk-go/service/elbv2/elbv2iface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ELBV2Repository interface { - ListAllLoadBalancers() ([]*elbv2.LoadBalancer, error) - ListAllLoadBalancerListeners(string) ([]*elbv2.Listener, error) -} - -type elbv2Repository struct { - client elbv2iface.ELBV2API - cache cache.Cache -} - -func NewELBV2Repository(session *session.Session, c cache.Cache) *elbv2Repository { - return &elbv2Repository{ - elbv2.New(session), - c, - } -} - -func (r *elbv2Repository) ListAllLoadBalancers() ([]*elbv2.LoadBalancer, error) { - cacheKey := "elbv2ListAllLoadBalancers" - defer r.cache.Unlock(cacheKey) - if v := r.cache.GetAndLock(cacheKey); v != nil { - return v.([]*elbv2.LoadBalancer), nil - } - - results := make([]*elbv2.LoadBalancer, 0) - input := &elbv2.DescribeLoadBalancersInput{} - err := r.client.DescribeLoadBalancersPages(input, func(res *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool { - results = append(results, res.LoadBalancers...) - return !lastPage - }) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, results) - return results, err -} - -func (r *elbv2Repository) ListAllLoadBalancerListeners(loadBalancerArn string) ([]*elbv2.Listener, error) { - if v := r.cache.Get("elbv2ListAllLoadBalancerListeners"); v != nil { - return v.([]*elbv2.Listener), nil - } - - results := make([]*elbv2.Listener, 0) - input := &elbv2.DescribeListenersInput{ - LoadBalancerArn: &loadBalancerArn, - } - err := r.client.DescribeListenersPages(input, func(res *elbv2.DescribeListenersOutput, lastPage bool) bool { - results = append(results, res.Listeners...) - return !lastPage - }) - if err != nil { - return nil, err - } - r.cache.Put("elbv2ListAllLoadBalancerListeners", results) - return results, err -} diff --git a/pkg/remote/aws/repository/elbv2_repository_test.go b/pkg/remote/aws/repository/elbv2_repository_test.go deleted file mode 100644 index a96b9690..00000000 --- a/pkg/remote/aws/repository/elbv2_repository_test.go +++ /dev/null @@ -1,243 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_ELBV2Repository_ListAllLoadBalancers(t *testing.T) { - dummyError := errors.New("dummy error") - - tests := []struct { - name string - mocks func(*awstest.MockFakeELBV2, *cache.MockCache) - want []*elbv2.LoadBalancer - wantErr error - }{ - { - name: "list load balancers", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - results := &elbv2.DescribeLoadBalancersOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String("test-1"), - LoadBalancerName: aws.String("test-1"), - }, - { - LoadBalancerArn: aws.String("test-2"), - LoadBalancerName: aws.String("test-2"), - }, - }, - } - - store.On("GetAndLock", "elbv2ListAllLoadBalancers").Return(nil).Once() - store.On("Unlock", "elbv2ListAllLoadBalancers").Return().Once() - - client.On("DescribeLoadBalancersPages", - &elbv2.DescribeLoadBalancersInput{}, - mock.MatchedBy(func(callback func(res *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool) bool { - callback(&elbv2.DescribeLoadBalancersOutput{LoadBalancers: []*elbv2.LoadBalancer{ - results.LoadBalancers[0], - }}, false) - callback(&elbv2.DescribeLoadBalancersOutput{LoadBalancers: []*elbv2.LoadBalancer{ - results.LoadBalancers[1], - }}, true) - return true - })).Return(nil).Once() - - store.On("Put", "elbv2ListAllLoadBalancers", results.LoadBalancers).Return(false).Once() - }, - want: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String("test-1"), - LoadBalancerName: aws.String("test-1"), - }, - { - LoadBalancerArn: aws.String("test-2"), - LoadBalancerName: aws.String("test-2"), - }, - }, - }, - { - name: "list load balancers from cache", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - output := &elbv2.DescribeLoadBalancersOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String("test-1"), - LoadBalancerName: aws.String("test-1"), - }, - }, - } - - store.On("GetAndLock", "elbv2ListAllLoadBalancers").Return(output.LoadBalancers).Once() - store.On("Unlock", "elbv2ListAllLoadBalancers").Return().Once() - }, - want: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String("test-1"), - LoadBalancerName: aws.String("test-1"), - }, - }, - }, - { - name: "error listing load balancers", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - store.On("GetAndLock", "elbv2ListAllLoadBalancers").Return(nil).Once() - store.On("Unlock", "elbv2ListAllLoadBalancers").Return().Once() - - client.On("DescribeLoadBalancersPages", - &elbv2.DescribeLoadBalancersInput{}, - mock.MatchedBy(func(callback func(res *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool) bool { - callback(&elbv2.DescribeLoadBalancersOutput{LoadBalancers: []*elbv2.LoadBalancer{}}, true) - return true - })).Return(dummyError).Once() - }, - wantErr: dummyError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeELBV2{} - tt.mocks(client, store) - r := &elbv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllLoadBalancers() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_ELBV2Repository_ListAllLoadBalancerListeners(t *testing.T) { - dummyError := errors.New("dummy error") - - tests := []struct { - name string - mocks func(*awstest.MockFakeELBV2, *cache.MockCache) - want []*elbv2.Listener - wantErr error - }{ - { - name: "list load balancer listeners", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - results := &elbv2.DescribeListenersOutput{ - Listeners: []*elbv2.Listener{ - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener-1"), - }, - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener-2"), - }, - }, - } - - store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(nil).Once() - - client.On("DescribeListenersPages", - &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, - mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { - callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ - results.Listeners[0], - }}, false) - callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{ - results.Listeners[1], - }}, true) - return true - })).Return(nil).Once() - - store.On("Put", "elbv2ListAllLoadBalancerListeners", results.Listeners).Return(false).Once() - }, - want: []*elbv2.Listener{ - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener-1"), - }, - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener-2"), - }, - }, - }, - { - name: "list load balancer listeners from cache", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - output := &elbv2.DescribeListenersOutput{ - Listeners: []*elbv2.Listener{ - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener"), - }, - }, - } - - store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(output.Listeners).Once() - }, - want: []*elbv2.Listener{ - { - LoadBalancerArn: aws.String("test-lb"), - ListenerArn: aws.String("test-lb-listener"), - }, - }, - }, - { - name: "error listing load balancer listeners", - mocks: func(client *awstest.MockFakeELBV2, store *cache.MockCache) { - store.On("Get", "elbv2ListAllLoadBalancerListeners").Return(nil).Once() - - client.On("DescribeListenersPages", - &elbv2.DescribeListenersInput{LoadBalancerArn: aws.String("test-lb")}, - mock.MatchedBy(func(callback func(res *elbv2.DescribeListenersOutput, lastPage bool) bool) bool { - callback(&elbv2.DescribeListenersOutput{Listeners: []*elbv2.Listener{}}, true) - return true - })).Return(dummyError).Once() - }, - wantErr: dummyError, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeELBV2{} - tt.mocks(client, store) - r := &elbv2Repository{ - client: client, - cache: store, - } - got, err := r.ListAllLoadBalancerListeners("test-lb") - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/iam_repository.go b/pkg/remote/aws/repository/iam_repository.go deleted file mode 100644 index c828c1db..00000000 --- a/pkg/remote/aws/repository/iam_repository.go +++ /dev/null @@ -1,367 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type IAMRepository interface { - ListAllAccessKeys([]*iam.User) ([]*iam.AccessKeyMetadata, error) - ListAllUsers() ([]*iam.User, error) - ListAllPolicies() ([]*iam.Policy, error) - ListAllRoles() ([]*iam.Role, error) - ListAllRolePolicyAttachments([]*iam.Role) ([]*AttachedRolePolicy, error) - ListAllRolePolicies([]*iam.Role) ([]RolePolicy, error) - ListAllUserPolicyAttachments([]*iam.User) ([]*AttachedUserPolicy, error) - ListAllUserPolicies([]*iam.User) ([]string, error) - ListAllGroups() ([]*iam.Group, error) - ListAllGroupPolicies([]*iam.Group) ([]string, error) - ListAllGroupPolicyAttachments([]*iam.Group) ([]*AttachedGroupPolicy, error) -} - -type iamRepository struct { - client iamiface.IAMAPI - cache cache.Cache -} - -func NewIAMRepository(session *session.Session, c cache.Cache) *iamRepository { - return &iamRepository{ - iam.New(session), - c, - } -} - -func (r *iamRepository) ListAllAccessKeys(users []*iam.User) ([]*iam.AccessKeyMetadata, error) { - var resources []*iam.AccessKeyMetadata - for _, user := range users { - cacheKey := fmt.Sprintf("iamListAllAccessKeys_user_%s", *user.UserName) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]*iam.AccessKeyMetadata)...) - continue - } - - userResources := make([]*iam.AccessKeyMetadata, 0) - input := &iam.ListAccessKeysInput{ - UserName: user.UserName, - } - err := r.client.ListAccessKeysPages(input, func(res *iam.ListAccessKeysOutput, lastPage bool) bool { - userResources = append(userResources, res.AccessKeyMetadata...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, userResources) - resources = append(resources, userResources...) - } - - return resources, nil -} - -func (r *iamRepository) ListAllUsers() ([]*iam.User, error) { - - cacheKey := "iamListAllUsers" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*iam.User), nil - } - - var resources []*iam.User - input := &iam.ListUsersInput{} - err := r.client.ListUsersPages(input, func(res *iam.ListUsersOutput, lastPage bool) bool { - resources = append(resources, res.Users...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources) - return resources, nil -} - -func (r *iamRepository) ListAllPolicies() ([]*iam.Policy, error) { - if v := r.cache.Get("iamListAllPolicies"); v != nil { - return v.([]*iam.Policy), nil - } - - var resources []*iam.Policy - input := &iam.ListPoliciesInput{ - Scope: aws.String(iam.PolicyScopeTypeLocal), - } - err := r.client.ListPoliciesPages(input, func(res *iam.ListPoliciesOutput, lastPage bool) bool { - resources = append(resources, res.Policies...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("iamListAllPolicies", resources) - return resources, nil -} - -func (r *iamRepository) ListAllRoles() ([]*iam.Role, error) { - cacheKey := "iamListAllRoles" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*iam.Role), nil - } - - var resources []*iam.Role - input := &iam.ListRolesInput{} - err := r.client.ListRolesPages(input, func(res *iam.ListRolesOutput, lastPage bool) bool { - resources = append(resources, res.Roles...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources) - return resources, nil -} - -func (r *iamRepository) ListAllRolePolicyAttachments(roles []*iam.Role) ([]*AttachedRolePolicy, error) { - var resources []*AttachedRolePolicy - for _, role := range roles { - cacheKey := fmt.Sprintf("iamListAllRolePolicyAttachments_role_%s", *role.RoleName) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]*AttachedRolePolicy)...) - continue - } - - roleResources := make([]*AttachedRolePolicy, 0) - input := &iam.ListAttachedRolePoliciesInput{ - RoleName: role.RoleName, - } - err := r.client.ListAttachedRolePoliciesPages(input, func(res *iam.ListAttachedRolePoliciesOutput, lastPage bool) bool { - for _, policy := range res.AttachedPolicies { - p := *policy - roleResources = append(roleResources, &AttachedRolePolicy{ - AttachedPolicy: p, - RoleName: *input.RoleName, - }) - } - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, roleResources) - resources = append(resources, roleResources...) - } - - return resources, nil -} - -func (r *iamRepository) ListAllRolePolicies(roles []*iam.Role) ([]RolePolicy, error) { - var resources []RolePolicy - for _, role := range roles { - cacheKey := fmt.Sprintf("iamListAllRolePolicies_role_%s", *role.RoleName) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]RolePolicy)...) - continue - } - - roleResources := make([]RolePolicy, 0) - input := &iam.ListRolePoliciesInput{ - RoleName: role.RoleName, - } - err := r.client.ListRolePoliciesPages(input, func(res *iam.ListRolePoliciesOutput, lastPage bool) bool { - for _, policy := range res.PolicyNames { - roleResources = append(roleResources, RolePolicy{*policy, *input.RoleName}) - } - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, roleResources) - resources = append(resources, roleResources...) - } - - return resources, nil -} - -func (r *iamRepository) ListAllUserPolicyAttachments(users []*iam.User) ([]*AttachedUserPolicy, error) { - var resources []*AttachedUserPolicy - for _, user := range users { - cacheKey := fmt.Sprintf("iamListAllUserPolicyAttachments_user_%s", *user.UserName) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]*AttachedUserPolicy)...) - continue - } - - userResources := make([]*AttachedUserPolicy, 0) - input := &iam.ListAttachedUserPoliciesInput{ - UserName: user.UserName, - } - err := r.client.ListAttachedUserPoliciesPages(input, func(res *iam.ListAttachedUserPoliciesOutput, lastPage bool) bool { - for _, policy := range res.AttachedPolicies { - p := *policy - userResources = append(userResources, &AttachedUserPolicy{ - AttachedPolicy: p, - UserName: *input.UserName, - }) - } - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, userResources) - resources = append(resources, userResources...) - } - - return resources, nil -} - -func (r *iamRepository) ListAllUserPolicies(users []*iam.User) ([]string, error) { - var resources []string - for _, user := range users { - cacheKey := fmt.Sprintf("iamListAllUserPolicies_user_%s", *user.UserName) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]string)...) - continue - } - - userResources := make([]string, 0) - input := &iam.ListUserPoliciesInput{ - UserName: user.UserName, - } - err := r.client.ListUserPoliciesPages(input, func(res *iam.ListUserPoliciesOutput, lastPage bool) bool { - for _, polName := range res.PolicyNames { - userResources = append(userResources, fmt.Sprintf("%s:%s", *input.UserName, *polName)) - } - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, userResources) - resources = append(resources, userResources...) - } - - return resources, nil -} - -func (r *iamRepository) ListAllGroups() ([]*iam.Group, error) { - - cacheKey := "iamListAllGroups" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - - if v != nil { - return v.([]*iam.Group), nil - } - - var resources []*iam.Group - input := &iam.ListGroupsInput{} - err := r.client.ListGroupsPages(input, func(res *iam.ListGroupsOutput, lastPage bool) bool { - resources = append(resources, res.Groups...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, resources) - return resources, nil -} - -func (r *iamRepository) ListAllGroupPolicies(groups []*iam.Group) ([]string, error) { - var resources []string - for _, group := range groups { - cacheKey := fmt.Sprintf("iamListAllGroupPolicies_group_%s", *group.GroupName) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]string)...) - continue - } - - groupResources := make([]string, 0) - input := &iam.ListGroupPoliciesInput{ - GroupName: group.GroupName, - } - err := r.client.ListGroupPoliciesPages(input, func(res *iam.ListGroupPoliciesOutput, lastPage bool) bool { - for _, polName := range res.PolicyNames { - groupResources = append(groupResources, fmt.Sprintf("%s:%s", *input.GroupName, *polName)) - } - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, groupResources) - resources = append(resources, groupResources...) - } - - return resources, nil -} - -func (r *iamRepository) ListAllGroupPolicyAttachments(groups []*iam.Group) ([]*AttachedGroupPolicy, error) { - var resources []*AttachedGroupPolicy - for _, group := range groups { - cacheKey := fmt.Sprintf("iamListAllGroupPolicyAttachments_%s", *group.GroupId) - if v := r.cache.Get(cacheKey); v != nil { - resources = append(resources, v.([]*AttachedGroupPolicy)...) - continue - } - - attachedGroupPolicies := make([]*AttachedGroupPolicy, 0) - input := &iam.ListAttachedGroupPoliciesInput{ - GroupName: group.GroupName, - } - err := r.client.ListAttachedGroupPoliciesPages(input, func(res *iam.ListAttachedGroupPoliciesOutput, lastPage bool) bool { - for _, policy := range res.AttachedPolicies { - p := *policy - attachedGroupPolicies = append(attachedGroupPolicies, &AttachedGroupPolicy{ - AttachedPolicy: p, - GroupName: *input.GroupName, - }) - } - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, attachedGroupPolicies) - resources = append(resources, attachedGroupPolicies...) - } - - return resources, nil -} - -type AttachedUserPolicy struct { - iam.AttachedPolicy - UserName string -} - -type AttachedRolePolicy struct { - iam.AttachedPolicy - RoleName string -} - -type AttachedGroupPolicy struct { - iam.AttachedPolicy - GroupName string -} - -type RolePolicy struct { - Policy string - RoleName string -} diff --git a/pkg/remote/aws/repository/iam_repository_test.go b/pkg/remote/aws/repository/iam_repository_test.go deleted file mode 100644 index 2aff283e..00000000 --- a/pkg/remote/aws/repository/iam_repository_test.go +++ /dev/null @@ -1,1100 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_IAMRepository_ListAllAccessKeys(t *testing.T) { - tests := []struct { - name string - users []*iam.User - mocks func(client *awstest.MockFakeIAM) - want []*iam.AccessKeyMetadata - wantErr error - }{ - { - name: "List only access keys with multiple pages", - users: []*iam.User{ - { - UserName: aws.String("test-driftctl"), - }, - { - UserName: aws.String("test-driftctl2"), - }, - }, - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListAccessKeysPages", - &iam.ListAccessKeysInput{ - UserName: aws.String("test-driftctl"), - }, - mock.MatchedBy(func(callback func(res *iam.ListAccessKeysOutput, lastPage bool) bool) bool { - callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ - { - AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), - UserName: aws.String("test-driftctl"), - }, - }}, false) - callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ - { - AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), - UserName: aws.String("test-driftctl"), - }, - }}, true) - return true - })).Return(nil).Once() - client.On("ListAccessKeysPages", - &iam.ListAccessKeysInput{ - UserName: aws.String("test-driftctl2"), - }, - mock.MatchedBy(func(callback func(res *iam.ListAccessKeysOutput, lastPage bool) bool) bool { - callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ - { - AccessKeyId: aws.String("AKIA5QYBVVD26EJME25D"), - UserName: aws.String("test-driftctl2"), - }, - }}, false) - callback(&iam.ListAccessKeysOutput{AccessKeyMetadata: []*iam.AccessKeyMetadata{ - { - AccessKeyId: aws.String("AKIA5QYBVVD2SWDFVVMG"), - UserName: aws.String("test-driftctl2"), - }, - }}, true) - return true - })).Return(nil).Once() - }, - want: []*iam.AccessKeyMetadata{ - { - AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), - UserName: aws.String("test-driftctl"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), - UserName: aws.String("test-driftctl"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), - UserName: aws.String("test-driftctl"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), - UserName: aws.String("test-driftctl"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD26EJME25D"), - UserName: aws.String("test-driftctl2"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD2SWDFVVMG"), - UserName: aws.String("test-driftctl2"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllAccessKeys(tt.users) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllAccessKeys(tt.users) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - for _, user := range tt.users { - assert.IsType(t, []*iam.AccessKeyMetadata{}, store.Get(fmt.Sprintf("iamListAllAccessKeys_user_%s", *user.UserName))) - } - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllUsers(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeIAM) - want []*iam.User - wantErr error - }{ - { - name: "List only users with multiple pages", - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListUsersPages", - &iam.ListUsersInput{}, - mock.MatchedBy(func(callback func(res *iam.ListUsersOutput, lastPage bool) bool) bool { - callback(&iam.ListUsersOutput{Users: []*iam.User{ - { - UserName: aws.String("test-driftctl"), - }, - { - UserName: aws.String("test-driftctl2"), - }, - }}, false) - callback(&iam.ListUsersOutput{Users: []*iam.User{ - { - UserName: aws.String("test-driftctl3"), - }, - { - UserName: aws.String("test-driftctl4"), - }, - }}, true) - return true - })).Return(nil).Once() - }, - want: []*iam.User{ - { - UserName: aws.String("test-driftctl"), - }, - { - UserName: aws.String("test-driftctl2"), - }, - { - UserName: aws.String("test-driftctl3"), - }, - { - UserName: aws.String("test-driftctl4"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllUsers() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllUsers() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*iam.User{}, store.Get("iamListAllUsers")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllPolicies(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeIAM) - want []*iam.Policy - wantErr error - }{ - { - name: "List only policies with multiple pages", - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListPoliciesPages", - &iam.ListPoliciesInput{Scope: aws.String(iam.PolicyScopeTypeLocal)}, - mock.MatchedBy(func(callback func(res *iam.ListPoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListPoliciesOutput{Policies: []*iam.Policy{ - { - PolicyName: aws.String("test-driftctl"), - }, - { - PolicyName: aws.String("test-driftctl2"), - }, - }}, false) - callback(&iam.ListPoliciesOutput{Policies: []*iam.Policy{ - { - PolicyName: aws.String("test-driftctl3"), - }, - { - PolicyName: aws.String("test-driftctl4"), - }, - }}, true) - return true - })).Return(nil).Once() - }, - want: []*iam.Policy{ - { - PolicyName: aws.String("test-driftctl"), - }, - { - PolicyName: aws.String("test-driftctl2"), - }, - { - PolicyName: aws.String("test-driftctl3"), - }, - { - PolicyName: aws.String("test-driftctl4"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllPolicies() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllPolicies() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*iam.Policy{}, store.Get("iamListAllPolicies")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllRoles(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeIAM) - want []*iam.Role - wantErr error - }{ - { - name: "List only roles with multiple pages", - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListRolesPages", - &iam.ListRolesInput{}, - mock.MatchedBy(func(callback func(res *iam.ListRolesOutput, lastPage bool) bool) bool { - callback(&iam.ListRolesOutput{Roles: []*iam.Role{ - { - RoleName: aws.String("test-driftctl"), - }, - { - RoleName: aws.String("test-driftctl2"), - }, - }}, false) - callback(&iam.ListRolesOutput{Roles: []*iam.Role{ - { - RoleName: aws.String("test-driftctl3"), - }, - { - RoleName: aws.String("test-driftctl4"), - }, - }}, true) - return true - })).Return(nil).Once() - }, - want: []*iam.Role{ - { - RoleName: aws.String("test-driftctl"), - }, - { - RoleName: aws.String("test-driftctl2"), - }, - { - RoleName: aws.String("test-driftctl3"), - }, - { - RoleName: aws.String("test-driftctl4"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRoles() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllRoles() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*iam.Role{}, store.Get("iamListAllRoles")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllRolePolicyAttachments(t *testing.T) { - tests := []struct { - name string - roles []*iam.Role - mocks func(client *awstest.MockFakeIAM) - want []*AttachedRolePolicy - wantErr error - }{ - { - name: "List only role policy attachments with multiple pages", - roles: []*iam.Role{ - { - RoleName: aws.String("test-role"), - }, - { - RoleName: aws.String("test-role2"), - }, - }, - mocks: func(client *awstest.MockFakeIAM) { - - shouldSkipfirst := false - shouldSkipSecond := false - - client.On("ListAttachedRolePoliciesPages", - &iam.ListAttachedRolePoliciesInput{ - RoleName: aws.String("test-role"), - }, - mock.MatchedBy(func(callback func(res *iam.ListAttachedRolePoliciesOutput, lastPage bool) bool) bool { - if shouldSkipfirst { - return false - } - callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), - PolicyName: aws.String("policy"), - }, - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), - PolicyName: aws.String("policy2"), - }, - }}, false) - callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), - PolicyName: aws.String("policy3"), - }, - }}, true) - shouldSkipfirst = true - return true - })).Return(nil).Once() - - client.On("ListAttachedRolePoliciesPages", - &iam.ListAttachedRolePoliciesInput{ - RoleName: aws.String("test-role2"), - }, - mock.MatchedBy(func(callback func(res *iam.ListAttachedRolePoliciesOutput, lastPage bool) bool) bool { - if shouldSkipSecond { - return false - } - callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), - PolicyName: aws.String("policy"), - }, - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), - PolicyName: aws.String("policy2"), - }, - }}, false) - callback(&iam.ListAttachedRolePoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), - PolicyName: aws.String("policy3"), - }, - }}, true) - shouldSkipSecond = true - return true - })).Return(nil).Once() - }, - want: []*AttachedRolePolicy{ - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), - PolicyName: aws.String("policy"), - }, - *aws.String("test-role"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), - PolicyName: aws.String("policy2"), - }, - *aws.String("test-role"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), - PolicyName: aws.String("policy3"), - }, - *aws.String("test-role"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy"), - PolicyName: aws.String("policy"), - }, - *aws.String("test-role2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy2"), - PolicyName: aws.String("policy2"), - }, - *aws.String("test-role2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test-policy3"), - PolicyName: aws.String("policy3"), - }, - *aws.String("test-role2"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRolePolicyAttachments(tt.roles) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllRolePolicyAttachments(tt.roles) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - for _, role := range tt.roles { - assert.IsType(t, []*AttachedRolePolicy{}, store.Get(fmt.Sprintf("iamListAllRolePolicyAttachments_role_%s", *role.RoleName))) - } - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllRolePolicies(t *testing.T) { - tests := []struct { - name string - roles []*iam.Role - mocks func(client *awstest.MockFakeIAM) - want []RolePolicy - wantErr error - }{ - { - name: "List only role policies with multiple pages", - roles: []*iam.Role{ - { - RoleName: aws.String("test_role_0"), - }, - { - RoleName: aws.String("test_role_1"), - }, - }, - mocks: func(client *awstest.MockFakeIAM) { - firstMockCalled := false - client.On("ListRolePoliciesPages", - &iam.ListRolePoliciesInput{ - RoleName: aws.String("test_role_0"), - }, - mock.MatchedBy(func(callback func(res *iam.ListRolePoliciesOutput, lastPage bool) bool) bool { - if firstMockCalled { - return false - } - callback(&iam.ListRolePoliciesOutput{ - PolicyNames: []*string{ - aws.String("policy-role0-0"), - aws.String("policy-role0-1"), - }, - }, false) - callback(&iam.ListRolePoliciesOutput{ - PolicyNames: []*string{ - aws.String("policy-role0-2"), - }, - }, true) - firstMockCalled = true - return true - })).Once().Return(nil) - client.On("ListRolePoliciesPages", - &iam.ListRolePoliciesInput{ - RoleName: aws.String("test_role_1"), - }, - mock.MatchedBy(func(callback func(res *iam.ListRolePoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListRolePoliciesOutput{ - PolicyNames: []*string{ - aws.String("policy-role1-0"), - aws.String("policy-role1-1"), - }, - }, false) - callback(&iam.ListRolePoliciesOutput{ - PolicyNames: []*string{ - aws.String("policy-role1-2"), - }, - }, true) - return true - })).Once().Return(nil) - }, - want: []RolePolicy{ - {Policy: "policy-role0-0", RoleName: "test_role_0"}, - {Policy: "policy-role0-1", RoleName: "test_role_0"}, - {Policy: "policy-role0-2", RoleName: "test_role_0"}, - {Policy: "policy-role1-0", RoleName: "test_role_1"}, - {Policy: "policy-role1-1", RoleName: "test_role_1"}, - {Policy: "policy-role1-2", RoleName: "test_role_1"}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllRolePolicies(tt.roles) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllRolePolicies(tt.roles) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - for _, role := range tt.roles { - assert.IsType(t, []RolePolicy{}, store.Get(fmt.Sprintf("iamListAllRolePolicies_role_%s", *role.RoleName))) - } - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllUserPolicyAttachments(t *testing.T) { - tests := []struct { - name string - users []*iam.User - mocks func(client *awstest.MockFakeIAM) - want []*AttachedUserPolicy - wantErr error - }{ - { - name: "List only user policy attachments with multiple pages", - users: []*iam.User{ - { - UserName: aws.String("loadbalancer"), - }, - { - UserName: aws.String("loadbalancer2"), - }, - }, - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListAttachedUserPoliciesPages", - &iam.ListAttachedUserPoliciesInput{ - UserName: aws.String("loadbalancer"), - }, - mock.MatchedBy(func(callback func(res *iam.ListAttachedUserPoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), - PolicyName: aws.String("test-attach"), - }, - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), - PolicyName: aws.String("test-attach2"), - }, - }}, false) - callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), - PolicyName: aws.String("test-attach3"), - }, - }}, true) - return true - })).Return(nil).Once() - - client.On("ListAttachedUserPoliciesPages", - &iam.ListAttachedUserPoliciesInput{ - UserName: aws.String("loadbalancer2"), - }, - mock.MatchedBy(func(callback func(res *iam.ListAttachedUserPoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), - PolicyName: aws.String("test-attach"), - }, - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), - PolicyName: aws.String("test-attach2"), - }, - }}, false) - callback(&iam.ListAttachedUserPoliciesOutput{AttachedPolicies: []*iam.AttachedPolicy{ - { - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), - PolicyName: aws.String("test-attach3"), - }, - }}, true) - return true - })).Return(nil).Once() - }, - - want: []*AttachedUserPolicy{ - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), - PolicyName: aws.String("test-attach"), - }, - *aws.String("loadbalancer"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), - PolicyName: aws.String("test-attach2"), - }, - *aws.String("loadbalancer"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), - PolicyName: aws.String("test-attach3"), - }, - *aws.String("loadbalancer"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), - PolicyName: aws.String("test-attach"), - }, - *aws.String("loadbalancer2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), - PolicyName: aws.String("test-attach2"), - }, - *aws.String("loadbalancer2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), - PolicyName: aws.String("test-attach3"), - }, - *aws.String("loadbalancer2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test"), - PolicyName: aws.String("test-attach"), - }, - *aws.String("loadbalancer2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test2"), - PolicyName: aws.String("test-attach2"), - }, - *aws.String("loadbalancer2"), - }, - { - iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::526954929923:policy/test3"), - PolicyName: aws.String("test-attach3"), - }, - *aws.String("loadbalancer2"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllUserPolicyAttachments(tt.users) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllUserPolicyAttachments(tt.users) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - for _, user := range tt.users { - assert.IsType(t, []*AttachedUserPolicy{}, store.Get(fmt.Sprintf("iamListAllUserPolicyAttachments_user_%s", *user.UserName))) - } - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllUserPolicies(t *testing.T) { - tests := []struct { - name string - users []*iam.User - mocks func(client *awstest.MockFakeIAM) - want []string - wantErr error - }{ - { - name: "List only user policies with multiple pages", - users: []*iam.User{ - { - UserName: aws.String("loadbalancer"), - }, - { - UserName: aws.String("loadbalancer2"), - }, - }, - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListUserPoliciesPages", - &iam.ListUserPoliciesInput{ - UserName: aws.String("loadbalancer"), - }, - mock.MatchedBy(func(callback func(res *iam.ListUserPoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ - aws.String("test"), - aws.String("test2"), - aws.String("test3"), - }}, false) - callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ - aws.String("test4"), - }}, true) - return true - })).Return(nil).Once() - - client.On("ListUserPoliciesPages", - &iam.ListUserPoliciesInput{ - UserName: aws.String("loadbalancer2"), - }, - mock.MatchedBy(func(callback func(res *iam.ListUserPoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ - aws.String("test2"), - aws.String("test22"), - aws.String("test23"), - }}, false) - callback(&iam.ListUserPoliciesOutput{PolicyNames: []*string{ - aws.String("test24"), - }}, true) - return true - })).Return(nil).Once() - }, - want: []string{ - *aws.String("loadbalancer:test"), - *aws.String("loadbalancer:test2"), - *aws.String("loadbalancer:test3"), - *aws.String("loadbalancer:test4"), - *aws.String("loadbalancer2:test"), - *aws.String("loadbalancer2:test2"), - *aws.String("loadbalancer2:test3"), - *aws.String("loadbalancer2:test4"), - *aws.String("loadbalancer2:test2"), - *aws.String("loadbalancer2:test22"), - *aws.String("loadbalancer2:test23"), - *aws.String("loadbalancer2:test24"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllUserPolicies(tt.users) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllUserPolicies(tt.users) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - for _, user := range tt.users { - assert.IsType(t, []string{}, store.Get(fmt.Sprintf("iamListAllUserPolicies_user_%s", *user.UserName))) - } - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllGroups(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeIAM) - want []*iam.Group - wantErr error - }{ - { - name: "List groups with multiple pages", - mocks: func(client *awstest.MockFakeIAM) { - - client.On("ListGroupsPages", - &iam.ListGroupsInput{}, - mock.MatchedBy(func(callback func(res *iam.ListGroupsOutput, lastPage bool) bool) bool { - callback(&iam.ListGroupsOutput{Groups: []*iam.Group{ - { - GroupName: aws.String("group1"), - }, - { - GroupName: aws.String("group2"), - }, - }}, false) - callback(&iam.ListGroupsOutput{Groups: []*iam.Group{ - { - GroupName: aws.String("group3"), - }, - { - GroupName: aws.String("group4"), - }, - }}, true) - return true - })).Return(nil).Once() - }, - want: []*iam.Group{ - { - GroupName: aws.String("group1"), - }, - { - GroupName: aws.String("group2"), - }, - { - GroupName: aws.String("group3"), - }, - { - GroupName: aws.String("group4"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllGroups() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllGroups() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*iam.Group{}, store.Get("iamListAllGroups")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_IAMRepository_ListAllGroupPolicies(t *testing.T) { - tests := []struct { - name string - groups []*iam.Group - mocks func(client *awstest.MockFakeIAM) - want []string - wantErr error - }{ - { - name: "List only group policies with multiple pages", - groups: []*iam.Group{ - { - GroupName: aws.String("group1"), - }, - { - GroupName: aws.String("group2"), - }, - }, - mocks: func(client *awstest.MockFakeIAM) { - firstMockCalled := false - client.On("ListGroupPoliciesPages", - &iam.ListGroupPoliciesInput{ - GroupName: aws.String("group1"), - }, - mock.MatchedBy(func(callback func(res *iam.ListGroupPoliciesOutput, lastPage bool) bool) bool { - if firstMockCalled { - return false - } - callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ - aws.String("policy1"), - aws.String("policy2"), - aws.String("policy3"), - }}, false) - callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ - aws.String("policy4"), - }}, true) - firstMockCalled = true - return true - })).Return(nil).Once() - - client.On("ListGroupPoliciesPages", - &iam.ListGroupPoliciesInput{ - GroupName: aws.String("group2"), - }, - mock.MatchedBy(func(callback func(res *iam.ListGroupPoliciesOutput, lastPage bool) bool) bool { - callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ - aws.String("policy2"), - aws.String("policy22"), - aws.String("policy23"), - }}, false) - callback(&iam.ListGroupPoliciesOutput{PolicyNames: []*string{ - aws.String("policy24"), - }}, true) - return true - })).Return(nil).Once() - }, - want: []string{ - *aws.String("group1:policy1"), - *aws.String("group1:policy2"), - *aws.String("group1:policy3"), - *aws.String("group1:policy4"), - *aws.String("group2:policy2"), - *aws.String("group2:policy22"), - *aws.String("group2:policy23"), - *aws.String("group2:policy24"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(2) - client := &awstest.MockFakeIAM{} - tt.mocks(client) - r := &iamRepository{ - client: client, - cache: store, - } - got, err := r.ListAllGroupPolicies(tt.groups) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllGroupPolicies(tt.groups) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - for _, group := range tt.groups { - assert.IsType(t, []string{}, store.Get(fmt.Sprintf("iamListAllGroupPolicies_group_%s", *group.GroupName))) - } - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/kms_repository.go b/pkg/remote/aws/repository/kms_repository.go deleted file mode 100644 index 2b21baff..00000000 --- a/pkg/remote/aws/repository/kms_repository.go +++ /dev/null @@ -1,146 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - "sync" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type KMSRepository interface { - ListAllKeys() ([]*kms.KeyListEntry, error) - ListAllAliases() ([]*kms.AliasListEntry, error) -} - -type kmsRepository struct { - client kmsiface.KMSAPI - cache cache.Cache - describeKeyLock *sync.Mutex -} - -func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository { - return &kmsRepository{ - kms.New(session), - c, - &sync.Mutex{}, - } -} - -func (r *kmsRepository) ListAllKeys() ([]*kms.KeyListEntry, error) { - if v := r.cache.Get("kmsListAllKeys"); v != nil { - return v.([]*kms.KeyListEntry), nil - } - - var keys []*kms.KeyListEntry - input := kms.ListKeysInput{} - err := r.client.ListKeysPages(&input, - func(resp *kms.ListKeysOutput, lastPage bool) bool { - keys = append(keys, resp.Keys...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - customerKeys, err := r.filterKeys(keys) - if err != nil { - return nil, err - } - - r.cache.Put("kmsListAllKeys", customerKeys) - return customerKeys, nil -} - -func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) { - if v := r.cache.Get("kmsListAllAliases"); v != nil { - return v.([]*kms.AliasListEntry), nil - } - - var aliases []*kms.AliasListEntry - input := kms.ListAliasesInput{} - err := r.client.ListAliasesPages(&input, - func(resp *kms.ListAliasesOutput, lastPage bool) bool { - aliases = append(aliases, resp.Aliases...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - result, err := r.filterAliases(aliases) - if err != nil { - return nil, err - } - r.cache.Put("kmsListAllAliases", result) - return result, nil -} - -func (r *kmsRepository) describeKey(keyId *string) (*kms.DescribeKeyOutput, error) { - var results interface{} - // Since this method can be call in parallel, we should lock and unlock if we want to be sure to hit the cache - r.describeKeyLock.Lock() - defer r.describeKeyLock.Unlock() - cacheKey := fmt.Sprintf("kmsDescribeKey-%s", *keyId) - results = r.cache.Get(cacheKey) - if results == nil { - var err error - results, err = r.client.DescribeKey(&kms.DescribeKeyInput{KeyId: keyId}) - if err != nil { - return nil, err - } - r.cache.Put(cacheKey, results) - } - describeKey := results.(*kms.DescribeKeyOutput) - if aws.StringValue(describeKey.KeyMetadata.KeyState) == kms.KeyStatePendingDeletion { - return nil, nil - } - return describeKey, nil -} - -func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) { - var customerKeys []*kms.KeyListEntry - for _, key := range keys { - k, err := r.describeKey(key.KeyId) - if err != nil { - return nil, err - } - if k == nil { - logrus.WithFields(logrus.Fields{ - "id": *key.KeyId, - }).Debug("Ignored kms key from listing since it is pending from deletion") - continue - } - if k.KeyMetadata.KeyManager != nil && *k.KeyMetadata.KeyManager != "AWS" { - customerKeys = append(customerKeys, key) - } - } - return customerKeys, nil -} - -func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) ([]*kms.AliasListEntry, error) { - var customerAliases []*kms.AliasListEntry - for _, alias := range aliases { - if alias.AliasName != nil && !strings.HasPrefix(*alias.AliasName, "alias/aws/") { - k, err := r.describeKey(alias.TargetKeyId) - if err != nil { - return nil, err - } - if k == nil { - logrus.WithFields(logrus.Fields{ - "id": *alias.TargetKeyId, - "alias": *alias.AliasName, - }).Debug("Ignored kms key alias from listing since it is linked to a pending from deletion key") - continue - } - customerAliases = append(customerAliases, alias) - } - } - return customerAliases, nil -} diff --git a/pkg/remote/aws/repository/kms_repository_test.go b/pkg/remote/aws/repository/kms_repository_test.go deleted file mode 100644 index 51ee23f3..00000000 --- a/pkg/remote/aws/repository/kms_repository_test.go +++ /dev/null @@ -1,249 +0,0 @@ -package repository - -import ( - "strings" - "sync" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_KMSRepository_ListAllKeys(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeKMS) - want []*kms.KeyListEntry - wantErr error - }{ - { - name: "List only enabled keys", - mocks: func(client *awstest.MockFakeKMS) { - client.On("ListKeysPages", - &kms.ListKeysInput{}, - mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool { - callback(&kms.ListKeysOutput{ - Keys: []*kms.KeyListEntry{ - {KeyId: aws.String("1")}, - {KeyId: aws.String("2")}, - }, - }, true) - return true - })).Return(nil).Once() - client.On("DescribeKey", - &kms.DescribeKeyInput{ - KeyId: aws.String("1"), - }).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyId: aws.String("1"), - KeyManager: aws.String("CUSTOMER"), - KeyState: aws.String(kms.KeyStateEnabled), - }, - }, nil).Once() - client.On("DescribeKey", - &kms.DescribeKeyInput{ - KeyId: aws.String("2"), - }).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyId: aws.String("2"), - KeyManager: aws.String("CUSTOMER"), - KeyState: aws.String(kms.KeyStatePendingDeletion), - }, - }, nil).Once() - }, - want: []*kms.KeyListEntry{ - {KeyId: aws.String("1")}, - }, - }, - { - name: "List only customer keys", - mocks: func(client *awstest.MockFakeKMS) { - client.On("ListKeysPages", - &kms.ListKeysInput{}, - mock.MatchedBy(func(callback func(res *kms.ListKeysOutput, lastPage bool) bool) bool { - callback(&kms.ListKeysOutput{ - Keys: []*kms.KeyListEntry{ - {KeyId: aws.String("1")}, - {KeyId: aws.String("2")}, - {KeyId: aws.String("3")}, - }, - }, true) - return true - })).Return(nil).Once() - client.On("DescribeKey", - &kms.DescribeKeyInput{ - KeyId: aws.String("1"), - }).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyId: aws.String("1"), - KeyManager: aws.String("CUSTOMER"), - KeyState: aws.String(kms.KeyStateEnabled), - }, - }, nil).Once() - client.On("DescribeKey", - &kms.DescribeKeyInput{ - KeyId: aws.String("2"), - }).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyId: aws.String("2"), - KeyManager: aws.String("AWS"), - KeyState: aws.String(kms.KeyStateEnabled), - }, - }, nil).Once() - client.On("DescribeKey", - &kms.DescribeKeyInput{ - KeyId: aws.String("3"), - }).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyId: aws.String("3"), - KeyManager: aws.String("AWS"), - KeyState: aws.String(kms.KeyStateEnabled), - }, - }, nil).Once() - }, - want: []*kms.KeyListEntry{ - {KeyId: aws.String("1")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeKMS{} - tt.mocks(&client) - r := &kmsRepository{ - client: &client, - cache: store, - describeKeyLock: &sync.Mutex{}, - } - got, err := r.ListAllKeys() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllKeys() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*kms.KeyListEntry{}, store.Get("kmsListAllKeys")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_KMSRepository_ListAllAliases(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeKMS) - want []*kms.AliasListEntry - wantErr error - }{ - { - name: "List only aliases for enabled keys", - mocks: func(client *awstest.MockFakeKMS) { - client.On("ListAliasesPages", - &kms.ListAliasesInput{}, - mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool { - callback(&kms.ListAliasesOutput{ - Aliases: []*kms.AliasListEntry{ - {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, - {AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")}, - }, - }, true) - return true - })).Return(nil).Once() - client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-1")}).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyState: aws.String(kms.KeyStatePendingDeletion), - }, - }, nil) - client.On("DescribeKey", &kms.DescribeKeyInput{KeyId: aws.String("key-id-2")}).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyState: aws.String(kms.KeyStateEnabled), - }, - }, nil) - }, - want: []*kms.AliasListEntry{ - {AliasName: aws.String("alias/2"), TargetKeyId: aws.String("key-id-2")}, - }, - }, - { - name: "List only customer aliases", - mocks: func(client *awstest.MockFakeKMS) { - client.On("ListAliasesPages", - &kms.ListAliasesInput{}, - mock.MatchedBy(func(callback func(res *kms.ListAliasesOutput, lastPage bool) bool) bool { - callback(&kms.ListAliasesOutput{ - Aliases: []*kms.AliasListEntry{ - {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, - {AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")}, - {AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")}, - {AliasName: aws.String("alias/aws/4"), TargetKeyId: aws.String("key-id-4")}, - {AliasName: aws.String("alias/aws/5"), TargetKeyId: aws.String("key-id-5")}, - {AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")}, - {AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")}, - }, - }, true) - return true - })).Return(nil).Once() - client.On("DescribeKey", mock.Anything).Return(&kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ - KeyState: aws.String(kms.KeyStateEnabled), - }, - }, nil) - }, - want: []*kms.AliasListEntry{ - {AliasName: aws.String("alias/1"), TargetKeyId: aws.String("key-id-1")}, - {AliasName: aws.String("alias/foo/2"), TargetKeyId: aws.String("key-id-2")}, - {AliasName: aws.String("alias/aw/3"), TargetKeyId: aws.String("key-id-3")}, - {AliasName: aws.String("alias/awss/6"), TargetKeyId: aws.String("key-id-6")}, - {AliasName: aws.String("alias/aws7"), TargetKeyId: aws.String("key-id-7")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeKMS{} - tt.mocks(&client) - r := &kmsRepository{ - client: &client, - cache: store, - describeKeyLock: &sync.Mutex{}, - } - got, err := r.ListAllAliases() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllAliases() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*kms.AliasListEntry{}, store.Get("kmsListAllAliases")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %v -> %v", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/lambda_repository.go b/pkg/remote/aws/repository/lambda_repository.go deleted file mode 100644 index 6b9b6fdf..00000000 --- a/pkg/remote/aws/repository/lambda_repository.go +++ /dev/null @@ -1,63 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/lambda" - "github.com/aws/aws-sdk-go/service/lambda/lambdaiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type LambdaRepository interface { - ListAllLambdaFunctions() ([]*lambda.FunctionConfiguration, error) - ListAllLambdaEventSourceMappings() ([]*lambda.EventSourceMappingConfiguration, error) -} - -type lambdaRepository struct { - client lambdaiface.LambdaAPI - cache cache.Cache -} - -func NewLambdaRepository(session *session.Session, c cache.Cache) *lambdaRepository { - return &lambdaRepository{ - lambda.New(session), - c, - } -} - -func (r *lambdaRepository) ListAllLambdaFunctions() ([]*lambda.FunctionConfiguration, error) { - if v := r.cache.Get("lambdaListAllLambdaFunctions"); v != nil { - return v.([]*lambda.FunctionConfiguration), nil - } - - var functions []*lambda.FunctionConfiguration - input := &lambda.ListFunctionsInput{} - err := r.client.ListFunctionsPages(input, func(res *lambda.ListFunctionsOutput, lastPage bool) bool { - functions = append(functions, res.Functions...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("lambdaListAllLambdaFunctions", functions) - return functions, nil -} - -func (r *lambdaRepository) ListAllLambdaEventSourceMappings() ([]*lambda.EventSourceMappingConfiguration, error) { - if v := r.cache.Get("lambdaListAllLambdaEventSourceMappings"); v != nil { - return v.([]*lambda.EventSourceMappingConfiguration), nil - } - - var eventSourceMappingConfigurations []*lambda.EventSourceMappingConfiguration - input := &lambda.ListEventSourceMappingsInput{} - err := r.client.ListEventSourceMappingsPages(input, func(res *lambda.ListEventSourceMappingsOutput, lastPage bool) bool { - eventSourceMappingConfigurations = append(eventSourceMappingConfigurations, res.EventSourceMappings...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("lambdaListAllLambdaEventSourceMappings", eventSourceMappingConfigurations) - return eventSourceMappingConfigurations, nil -} diff --git a/pkg/remote/aws/repository/lambda_repository_test.go b/pkg/remote/aws/repository/lambda_repository_test.go deleted file mode 100644 index 62762983..00000000 --- a/pkg/remote/aws/repository/lambda_repository_test.go +++ /dev/null @@ -1,169 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/aws/aws-sdk-go/service/lambda" - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_lambdaRepository_ListAllLambdaFunctions(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeLambda) - want []*lambda.FunctionConfiguration - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeLambda) { - client.On("ListFunctionsPages", - &lambda.ListFunctionsInput{}, - mock.MatchedBy(func(callback func(res *lambda.ListFunctionsOutput, lastPage bool) bool) bool { - callback(&lambda.ListFunctionsOutput{ - Functions: []*lambda.FunctionConfiguration{ - {FunctionName: aws.String("1")}, - {FunctionName: aws.String("2")}, - {FunctionName: aws.String("3")}, - {FunctionName: aws.String("4")}, - }, - }, false) - callback(&lambda.ListFunctionsOutput{ - Functions: []*lambda.FunctionConfiguration{ - {FunctionName: aws.String("5")}, - {FunctionName: aws.String("6")}, - {FunctionName: aws.String("7")}, - {FunctionName: aws.String("8")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*lambda.FunctionConfiguration{ - {FunctionName: aws.String("1")}, - {FunctionName: aws.String("2")}, - {FunctionName: aws.String("3")}, - {FunctionName: aws.String("4")}, - {FunctionName: aws.String("5")}, - {FunctionName: aws.String("6")}, - {FunctionName: aws.String("7")}, - {FunctionName: aws.String("8")}, - }, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeLambda{} - tt.mocks(client) - r := &lambdaRepository{ - client: client, - cache: store, - } - got, err := r.ListAllLambdaFunctions() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllLambdaFunctions() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*lambda.FunctionConfiguration{}, store.Get("lambdaListAllLambdaFunctions")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_lambdaRepository_ListAllLambdaEventSourceMappings(t *testing.T) { - tests := []struct { - name string - mocks func(mock *awstest.MockFakeLambda) - want []*lambda.EventSourceMappingConfiguration - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeLambda) { - client.On("ListEventSourceMappingsPages", - &lambda.ListEventSourceMappingsInput{}, - mock.MatchedBy(func(callback func(res *lambda.ListEventSourceMappingsOutput, lastPage bool) bool) bool { - callback(&lambda.ListEventSourceMappingsOutput{ - EventSourceMappings: []*lambda.EventSourceMappingConfiguration{ - {UUID: aws.String("1")}, - {UUID: aws.String("2")}, - {UUID: aws.String("3")}, - {UUID: aws.String("4")}, - }, - }, false) - callback(&lambda.ListEventSourceMappingsOutput{ - EventSourceMappings: []*lambda.EventSourceMappingConfiguration{ - {UUID: aws.String("5")}, - {UUID: aws.String("6")}, - {UUID: aws.String("7")}, - {UUID: aws.String("8")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*lambda.EventSourceMappingConfiguration{ - {UUID: aws.String("1")}, - {UUID: aws.String("2")}, - {UUID: aws.String("3")}, - {UUID: aws.String("4")}, - {UUID: aws.String("5")}, - {UUID: aws.String("6")}, - {UUID: aws.String("7")}, - {UUID: aws.String("8")}, - }, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeLambda{} - tt.mocks(client) - r := &lambdaRepository{ - client: client, - cache: store, - } - got, err := r.ListAllLambdaEventSourceMappings() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllLambdaEventSourceMappings() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*lambda.EventSourceMappingConfiguration{}, store.Get("lambdaListAllLambdaEventSourceMappings")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/rds_repository.go b/pkg/remote/aws/repository/rds_repository.go deleted file mode 100644 index 28d78a5d..00000000 --- a/pkg/remote/aws/repository/rds_repository.go +++ /dev/null @@ -1,82 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type RDSRepository interface { - ListAllDBInstances() ([]*rds.DBInstance, error) - ListAllDBSubnetGroups() ([]*rds.DBSubnetGroup, error) - ListAllDBClusters() ([]*rds.DBCluster, error) -} - -type rdsRepository struct { - client rdsiface.RDSAPI - cache cache.Cache -} - -func NewRDSRepository(session *session.Session, c cache.Cache) *rdsRepository { - return &rdsRepository{ - rds.New(session), - c, - } -} - -func (r *rdsRepository) ListAllDBInstances() ([]*rds.DBInstance, error) { - if v := r.cache.Get("rdsListAllDBInstances"); v != nil { - return v.([]*rds.DBInstance), nil - } - - var result []*rds.DBInstance - input := &rds.DescribeDBInstancesInput{} - err := r.client.DescribeDBInstancesPages(input, func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool { - result = append(result, res.DBInstances...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("rdsListAllDBInstances", result) - return result, nil -} - -func (r *rdsRepository) ListAllDBSubnetGroups() ([]*rds.DBSubnetGroup, error) { - if v := r.cache.Get("rdsListAllDBSubnetGroups"); v != nil { - return v.([]*rds.DBSubnetGroup), nil - } - - var subnetGroups []*rds.DBSubnetGroup - input := rds.DescribeDBSubnetGroupsInput{} - err := r.client.DescribeDBSubnetGroupsPages(&input, - func(resp *rds.DescribeDBSubnetGroupsOutput, lastPage bool) bool { - subnetGroups = append(subnetGroups, resp.DBSubnetGroups...) - return !lastPage - }, - ) - - r.cache.Put("rdsListAllDBSubnetGroups", subnetGroups) - return subnetGroups, err -} - -func (r *rdsRepository) ListAllDBClusters() ([]*rds.DBCluster, error) { - cacheKey := "rdsListAllDBClusters" - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*rds.DBCluster), nil - } - - var clusters []*rds.DBCluster - input := rds.DescribeDBClustersInput{} - err := r.client.DescribeDBClustersPages(&input, - func(resp *rds.DescribeDBClustersOutput, lastPage bool) bool { - clusters = append(clusters, resp.DBClusters...) - return !lastPage - }, - ) - - r.cache.Put(cacheKey, clusters) - return clusters, err -} diff --git a/pkg/remote/aws/repository/rds_repository_test.go b/pkg/remote/aws/repository/rds_repository_test.go deleted file mode 100644 index 5a4d3898..00000000 --- a/pkg/remote/aws/repository/rds_repository_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_rdsRepository_ListAllDBInstances(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeRDS) - want []*rds.DBInstance - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeRDS) { - client.On("DescribeDBInstancesPages", - &rds.DescribeDBInstancesInput{}, - mock.MatchedBy(func(callback func(res *rds.DescribeDBInstancesOutput, lastPage bool) bool) bool { - callback(&rds.DescribeDBInstancesOutput{ - DBInstances: []*rds.DBInstance{ - {DBInstanceIdentifier: aws.String("1")}, - {DBInstanceIdentifier: aws.String("2")}, - {DBInstanceIdentifier: aws.String("3")}, - }, - }, false) - callback(&rds.DescribeDBInstancesOutput{ - DBInstances: []*rds.DBInstance{ - {DBInstanceIdentifier: aws.String("4")}, - {DBInstanceIdentifier: aws.String("5")}, - {DBInstanceIdentifier: aws.String("6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*rds.DBInstance{ - {DBInstanceIdentifier: aws.String("1")}, - {DBInstanceIdentifier: aws.String("2")}, - {DBInstanceIdentifier: aws.String("3")}, - {DBInstanceIdentifier: aws.String("4")}, - {DBInstanceIdentifier: aws.String("5")}, - {DBInstanceIdentifier: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeRDS{} - tt.mocks(client) - r := &rdsRepository{ - client: client, - cache: store, - } - got, err := r.ListAllDBInstances() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllDBInstances() - assert.Nil(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*rds.DBInstance{}, store.Get("rdsListAllDBInstances")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_rdsRepository_ListAllDBSubnetGroups(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeRDS) - want []*rds.DBSubnetGroup - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeRDS) { - client.On("DescribeDBSubnetGroupsPages", - &rds.DescribeDBSubnetGroupsInput{}, - mock.MatchedBy(func(callback func(res *rds.DescribeDBSubnetGroupsOutput, lastPage bool) bool) bool { - callback(&rds.DescribeDBSubnetGroupsOutput{ - DBSubnetGroups: []*rds.DBSubnetGroup{ - {DBSubnetGroupName: aws.String("1")}, - {DBSubnetGroupName: aws.String("2")}, - {DBSubnetGroupName: aws.String("3")}, - }, - }, false) - callback(&rds.DescribeDBSubnetGroupsOutput{ - DBSubnetGroups: []*rds.DBSubnetGroup{ - {DBSubnetGroupName: aws.String("4")}, - {DBSubnetGroupName: aws.String("5")}, - {DBSubnetGroupName: aws.String("6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*rds.DBSubnetGroup{ - {DBSubnetGroupName: aws.String("1")}, - {DBSubnetGroupName: aws.String("2")}, - {DBSubnetGroupName: aws.String("3")}, - {DBSubnetGroupName: aws.String("4")}, - {DBSubnetGroupName: aws.String("5")}, - {DBSubnetGroupName: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeRDS{} - tt.mocks(client) - r := &rdsRepository{ - client: client, - cache: store, - } - got, err := r.ListAllDBSubnetGroups() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllDBSubnetGroups() - assert.Nil(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*rds.DBSubnetGroup{}, store.Get("rdsListAllDBSubnetGroups")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_rdsRepository_ListAllDBClusters(t *testing.T) { - tests := []struct { - name string - mocks func(*awstest.MockFakeRDS, *cache.MockCache) - want []*rds.DBCluster - wantErr error - }{ - { - name: "should list with 2 pages", - mocks: func(client *awstest.MockFakeRDS, store *cache.MockCache) { - clusters := []*rds.DBCluster{ - {DBClusterIdentifier: aws.String("1")}, - {DBClusterIdentifier: aws.String("2")}, - {DBClusterIdentifier: aws.String("3")}, - {DBClusterIdentifier: aws.String("4")}, - {DBClusterIdentifier: aws.String("5")}, - {DBClusterIdentifier: aws.String("6")}, - } - - client.On("DescribeDBClustersPages", - &rds.DescribeDBClustersInput{}, - mock.MatchedBy(func(callback func(res *rds.DescribeDBClustersOutput, lastPage bool) bool) bool { - callback(&rds.DescribeDBClustersOutput{DBClusters: clusters[:3]}, false) - callback(&rds.DescribeDBClustersOutput{DBClusters: clusters[3:]}, true) - return true - })).Return(nil).Once() - - store.On("Get", "rdsListAllDBClusters").Return(nil).Once() - store.On("Put", "rdsListAllDBClusters", clusters).Return(false).Once() - }, - want: []*rds.DBCluster{ - {DBClusterIdentifier: aws.String("1")}, - {DBClusterIdentifier: aws.String("2")}, - {DBClusterIdentifier: aws.String("3")}, - {DBClusterIdentifier: aws.String("4")}, - {DBClusterIdentifier: aws.String("5")}, - {DBClusterIdentifier: aws.String("6")}, - }, - }, - { - name: "should hit cache", - mocks: func(client *awstest.MockFakeRDS, store *cache.MockCache) { - clusters := []*rds.DBCluster{ - {DBClusterIdentifier: aws.String("1")}, - {DBClusterIdentifier: aws.String("2")}, - {DBClusterIdentifier: aws.String("3")}, - {DBClusterIdentifier: aws.String("4")}, - {DBClusterIdentifier: aws.String("5")}, - {DBClusterIdentifier: aws.String("6")}, - } - - store.On("Get", "rdsListAllDBClusters").Return(clusters).Once() - }, - want: []*rds.DBCluster{ - {DBClusterIdentifier: aws.String("1")}, - {DBClusterIdentifier: aws.String("2")}, - {DBClusterIdentifier: aws.String("3")}, - {DBClusterIdentifier: aws.String("4")}, - {DBClusterIdentifier: aws.String("5")}, - {DBClusterIdentifier: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := &cache.MockCache{} - client := &awstest.MockFakeRDS{} - tt.mocks(client, store) - r := &rdsRepository{ - client: client, - cache: store, - } - got, err := r.ListAllDBClusters() - assert.Equal(t, tt.wantErr, err) - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/route53_repository.go b/pkg/remote/aws/repository/route53_repository.go deleted file mode 100644 index 2e7dc962..00000000 --- a/pkg/remote/aws/repository/route53_repository.go +++ /dev/null @@ -1,92 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/aws/aws-sdk-go/service/route53/route53iface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type Route53Repository interface { - ListAllHealthChecks() ([]*route53.HealthCheck, error) - ListAllZones() ([]*route53.HostedZone, error) - ListRecordsForZone(zoneId string) ([]*route53.ResourceRecordSet, error) -} - -type route53Repository struct { - client route53iface.Route53API - cache cache.Cache -} - -func NewRoute53Repository(session *session.Session, c cache.Cache) *route53Repository { - return &route53Repository{ - route53.New(session), - c, - } -} - -func (r *route53Repository) ListAllHealthChecks() ([]*route53.HealthCheck, error) { - if v := r.cache.Get("route53ListAllHealthChecks"); v != nil { - return v.([]*route53.HealthCheck), nil - } - - var tables []*route53.HealthCheck - input := &route53.ListHealthChecksInput{} - err := r.client.ListHealthChecksPages(input, func(res *route53.ListHealthChecksOutput, lastPage bool) bool { - tables = append(tables, res.HealthChecks...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("route53ListAllHealthChecks", tables) - return tables, nil -} - -func (r *route53Repository) ListAllZones() ([]*route53.HostedZone, error) { - cacheKey := "route53ListAllZones" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*route53.HostedZone), nil - } - - var result []*route53.HostedZone - input := &route53.ListHostedZonesInput{} - err := r.client.ListHostedZonesPages(input, func(res *route53.ListHostedZonesOutput, lastPage bool) bool { - result = append(result, res.HostedZones...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, result) - return result, nil -} - -func (r *route53Repository) ListRecordsForZone(zoneId string) ([]*route53.ResourceRecordSet, error) { - cacheKey := fmt.Sprintf("route53ListRecordsForZone_%s", zoneId) - if v := r.cache.Get(cacheKey); v != nil { - return v.([]*route53.ResourceRecordSet), nil - } - - var results []*route53.ResourceRecordSet - input := &route53.ListResourceRecordSetsInput{ - HostedZoneId: aws.String(zoneId), - } - err := r.client.ListResourceRecordSetsPages(input, func(res *route53.ListResourceRecordSetsOutput, lastPage bool) bool { - results = append(results, res.ResourceRecordSets...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, results) - return results, nil -} diff --git a/pkg/remote/aws/repository/route53_repository_test.go b/pkg/remote/aws/repository/route53_repository_test.go deleted file mode 100644 index 7a652336..00000000 --- a/pkg/remote/aws/repository/route53_repository_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" -) - -func Test_route53Repository_ListAllHealthChecks(t *testing.T) { - - tests := []struct { - name string - mocks func(client *awstest.MockFakeRoute53) - want []*route53.HealthCheck - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeRoute53) { - client.On("ListHealthChecksPages", - &route53.ListHealthChecksInput{}, - mock.MatchedBy(func(callback func(res *route53.ListHealthChecksOutput, lastPage bool) bool) bool { - callback(&route53.ListHealthChecksOutput{ - HealthChecks: []*route53.HealthCheck{ - {Id: aws.String("1")}, - {Id: aws.String("2")}, - {Id: aws.String("3")}, - }, - }, false) - callback(&route53.ListHealthChecksOutput{ - HealthChecks: []*route53.HealthCheck{ - {Id: aws.String("4")}, - {Id: aws.String("5")}, - {Id: aws.String("6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*route53.HealthCheck{ - {Id: aws.String("1")}, - {Id: aws.String("2")}, - {Id: aws.String("3")}, - {Id: aws.String("4")}, - {Id: aws.String("5")}, - {Id: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeRoute53{} - tt.mocks(&client) - r := &route53Repository{ - client: &client, - cache: store, - } - got, err := r.ListAllHealthChecks() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllHealthChecks() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*route53.HealthCheck{}, store.Get("route53ListAllHealthChecks")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_route53Repository_ListAllZones(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeRoute53) - want []*route53.HostedZone - wantErr error - }{ - {name: "Zones with 2 pages", - mocks: func(client *awstest.MockFakeRoute53) { - client.On("ListHostedZonesPages", - &route53.ListHostedZonesInput{}, - mock.MatchedBy(func(callback func(res *route53.ListHostedZonesOutput, lastPage bool) bool) bool { - callback(&route53.ListHostedZonesOutput{ - HostedZones: []*route53.HostedZone{ - {Id: aws.String("1")}, - {Id: aws.String("2")}, - {Id: aws.String("3")}, - }, - }, false) - callback(&route53.ListHostedZonesOutput{ - HostedZones: []*route53.HostedZone{ - {Id: aws.String("4")}, - {Id: aws.String("5")}, - {Id: aws.String("6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*route53.HostedZone{ - {Id: aws.String("1")}, - {Id: aws.String("2")}, - {Id: aws.String("3")}, - {Id: aws.String("4")}, - {Id: aws.String("5")}, - {Id: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeRoute53{} - tt.mocks(&client) - r := &route53Repository{ - client: &client, - cache: store, - } - got, err := r.ListAllZones() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllZones() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*route53.HostedZone{}, store.Get("route53ListAllZones")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_route53Repository_ListRecordsForZone(t *testing.T) { - tests := []struct { - name string - zoneIds []string - mocks func(client *awstest.MockFakeRoute53) - want []*route53.ResourceRecordSet - wantErr error - }{ - { - name: "records for zone with 2 pages", - zoneIds: []string{ - "1", - }, - mocks: func(client *awstest.MockFakeRoute53) { - client.On("ListResourceRecordSetsPages", - &route53.ListResourceRecordSetsInput{ - HostedZoneId: aws.String("1"), - }, - mock.MatchedBy(func(callback func(res *route53.ListResourceRecordSetsOutput, lastPage bool) bool) bool { - callback(&route53.ListResourceRecordSetsOutput{ - ResourceRecordSets: []*route53.ResourceRecordSet{ - {Name: aws.String("1")}, - {Name: aws.String("2")}, - {Name: aws.String("3")}, - }, - }, false) - callback(&route53.ListResourceRecordSetsOutput{ - ResourceRecordSets: []*route53.ResourceRecordSet{ - {Name: aws.String("4")}, - {Name: aws.String("5")}, - {Name: aws.String("6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*route53.ResourceRecordSet{ - {Name: aws.String("1")}, - {Name: aws.String("2")}, - {Name: aws.String("3")}, - {Name: aws.String("4")}, - {Name: aws.String("5")}, - {Name: aws.String("6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := awstest.MockFakeRoute53{} - tt.mocks(&client) - r := &route53Repository{ - client: &client, - cache: store, - } - for _, id := range tt.zoneIds { - got, err := r.ListRecordsForZone(id) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListRecordsForZone(id) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*route53.ResourceRecordSet{}, store.Get(fmt.Sprintf("route53ListRecordsForZone_%s", id))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - } - }) - } -} diff --git a/pkg/remote/aws/repository/s3_repository.go b/pkg/remote/aws/repository/s3_repository.go deleted file mode 100644 index cabd14bf..00000000 --- a/pkg/remote/aws/repository/s3_repository.go +++ /dev/null @@ -1,287 +0,0 @@ -package repository - -import ( - "fmt" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/aws/client" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type S3Repository interface { - ListAllBuckets() ([]*s3.Bucket, error) - GetBucketNotification(bucketName, region string) (*s3.NotificationConfiguration, error) - GetBucketPolicy(bucketName, region string) (*string, error) - GetBucketPublicAccessBlock(bucketName, region string) (*s3.PublicAccessBlockConfiguration, error) - ListBucketInventoryConfigurations(bucket *s3.Bucket, region string) ([]*s3.InventoryConfiguration, error) - ListBucketMetricsConfigurations(bucket *s3.Bucket, region string) ([]*s3.MetricsConfiguration, error) - ListBucketAnalyticsConfigurations(bucket *s3.Bucket, region string) ([]*s3.AnalyticsConfiguration, error) - GetBucketLocation(bucketName string) (string, error) -} - -type s3Repository struct { - clientFactory client.AwsClientFactoryInterface - cache cache.Cache -} - -func NewS3Repository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3Repository { - return &s3Repository{ - factory, - c, - } -} - -func (s *s3Repository) ListAllBuckets() ([]*s3.Bucket, error) { - cacheKey := "s3ListAllBuckets" - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.([]*s3.Bucket), nil - } - - out, err := s.clientFactory.GetS3Client(nil).ListBuckets(&s3.ListBucketsInput{}) - if err != nil { - return nil, err - } - s.cache.Put(cacheKey, out.Buckets) - return out.Buckets, nil -} - -func (s *s3Repository) GetBucketPolicy(bucketName, region string) (*string, error) { - cacheKey := fmt.Sprintf("s3GetBucketPolicy_%s_%s", bucketName, region) - if v := s.cache.Get(cacheKey); v != nil { - return v.(*string), nil - } - policy, err := s.clientFactory. - GetS3Client(&awssdk.Config{Region: ®ion}). - GetBucketPolicy( - &s3.GetBucketPolicyInput{Bucket: &bucketName}, - ) - if err != nil { - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == "NoSuchBucketPolicy" { - return nil, nil - } - } - return nil, errors.Wrapf( - err, - "Error listing bucket policy %s", - bucketName, - ) - } - - result := policy.Policy - if result != nil && *result == "" { - result = nil - } - - s.cache.Put(cacheKey, result) - return result, nil -} - -func (s *s3Repository) GetBucketPublicAccessBlock(bucketName, region string) (*s3.PublicAccessBlockConfiguration, error) { - cacheKey := fmt.Sprintf("s3GetBucketPublicAccessBlock_%s_%s", bucketName, region) - if v := s.cache.Get(cacheKey); v != nil { - return v.(*s3.PublicAccessBlockConfiguration), nil - } - response, err := s.clientFactory. - GetS3Client(&awssdk.Config{Region: ®ion}). - GetPublicAccessBlock(&s3.GetPublicAccessBlockInput{Bucket: &bucketName}) - - if err != nil { - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == "NoSuchPublicAccessBlockConfiguration" { - return nil, nil - } - } - return nil, errors.Wrapf( - err, - "Error listing bucket public access block %s", - bucketName, - ) - } - - result := response.PublicAccessBlockConfiguration - - s.cache.Put(cacheKey, result) - return result, nil -} - -func (s *s3Repository) GetBucketNotification(bucketName, region string) (*s3.NotificationConfiguration, error) { - cacheKey := fmt.Sprintf("s3GetBucketNotification_%s_%s", bucketName, region) - if v := s.cache.Get(cacheKey); v != nil { - return v.(*s3.NotificationConfiguration), nil - } - bucketNotificationConfig, err := s.clientFactory. - GetS3Client(&awssdk.Config{Region: ®ion}). - GetBucketNotificationConfiguration( - &s3.GetBucketNotificationConfigurationRequest{Bucket: &bucketName}, - ) - if err != nil { - return nil, errors.Wrapf( - err, - "Error listing bucket notification configuration %s", - bucketName, - ) - } - - result := bucketNotificationConfig - if s.notificationIsEmpty(bucketNotificationConfig) { - result = nil - } - - s.cache.Put(cacheKey, result) - return result, nil -} - -func (s *s3Repository) notificationIsEmpty(notification *s3.NotificationConfiguration) bool { - return notification.TopicConfigurations == nil && - notification.QueueConfigurations == nil && - notification.LambdaFunctionConfigurations == nil -} - -func (s *s3Repository) ListBucketInventoryConfigurations(bucket *s3.Bucket, region string) ([]*s3.InventoryConfiguration, error) { - cacheKey := fmt.Sprintf("s3ListBucketInventoryConfigurations_%s_%s", *bucket.Name, region) - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*s3.InventoryConfiguration), nil - } - - inventoryConfigurations := make([]*s3.InventoryConfiguration, 0) - s3client := s.clientFactory.GetS3Client(&awssdk.Config{Region: ®ion}) - request := &s3.ListBucketInventoryConfigurationsInput{ - Bucket: bucket.Name, - ContinuationToken: nil, - } - - for { - configurations, err := s3client.ListBucketInventoryConfigurations(request) - if err != nil { - return nil, errors.Wrapf( - err, - "Error listing bucket inventory configuration %s", - *bucket.Name, - ) - } - inventoryConfigurations = append(inventoryConfigurations, configurations.InventoryConfigurationList...) - if configurations.IsTruncated != nil && *configurations.IsTruncated { - request.ContinuationToken = configurations.NextContinuationToken - } else { - break - } - } - - s.cache.Put(cacheKey, inventoryConfigurations) - return inventoryConfigurations, nil -} - -func (s *s3Repository) ListBucketMetricsConfigurations(bucket *s3.Bucket, region string) ([]*s3.MetricsConfiguration, error) { - cacheKey := fmt.Sprintf("s3ListBucketMetricsConfigurations_%s_%s", *bucket.Name, region) - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*s3.MetricsConfiguration), nil - } - - metricsConfigurationList := make([]*s3.MetricsConfiguration, 0) - s3client := s.clientFactory.GetS3Client(&awssdk.Config{Region: ®ion}) - request := &s3.ListBucketMetricsConfigurationsInput{ - Bucket: bucket.Name, - ContinuationToken: nil, - } - - for { - configurations, err := s3client.ListBucketMetricsConfigurations(request) - if err != nil { - return nil, errors.Wrapf( - err, - "Error listing bucket metrics configuration %s", - *bucket.Name, - ) - } - metricsConfigurationList = append(metricsConfigurationList, configurations.MetricsConfigurationList...) - if configurations.IsTruncated != nil && *configurations.IsTruncated { - request.ContinuationToken = configurations.NextContinuationToken - } else { - break - } - } - - s.cache.Put(cacheKey, metricsConfigurationList) - return metricsConfigurationList, nil -} - -func (s *s3Repository) ListBucketAnalyticsConfigurations(bucket *s3.Bucket, region string) ([]*s3.AnalyticsConfiguration, error) { - cacheKey := fmt.Sprintf("s3ListBucketAnalyticsConfigurations_%s_%s", *bucket.Name, region) - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*s3.AnalyticsConfiguration), nil - } - - analyticsConfigurationList := make([]*s3.AnalyticsConfiguration, 0) - s3client := s.clientFactory.GetS3Client(&awssdk.Config{Region: ®ion}) - request := &s3.ListBucketAnalyticsConfigurationsInput{ - Bucket: bucket.Name, - ContinuationToken: nil, - } - - for { - configurations, err := s3client.ListBucketAnalyticsConfigurations(request) - if err != nil { - return nil, errors.Wrapf( - err, - "Error listing bucket analytics configuration %s", - *bucket.Name, - ) - } - analyticsConfigurationList = append(analyticsConfigurationList, configurations.AnalyticsConfigurationList...) - - if configurations.IsTruncated != nil && *configurations.IsTruncated { - request.ContinuationToken = configurations.NextContinuationToken - } else { - break - } - } - - s.cache.Put(cacheKey, analyticsConfigurationList) - return analyticsConfigurationList, nil -} - -func (s *s3Repository) GetBucketLocation(bucketName string) (string, error) { - cacheKey := fmt.Sprintf("s3GetBucketLocation_%s", bucketName) - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.(string), nil - } - - bucketLocationRequest := s3.GetBucketLocationInput{Bucket: &bucketName} - bucketLocationResponse, err := s.clientFactory.GetS3Client(nil).GetBucketLocation(&bucketLocationRequest) - if err != nil { - awsErr, ok := err.(awserr.Error) - if ok && awsErr.Code() == s3.ErrCodeNoSuchBucket { - logrus.WithFields(logrus.Fields{ - "bucket": bucketName, - }).Warning("Unable to retrieve bucket region, this may be an inconsistency in S3 api for fresh deleted bucket, skipping ...") - return "", nil - } - return "", err - } - - var location string - - // Buckets in Region us-east-1 have a LocationConstraint of null. - // https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketLocation.html#API_GetBucketLocation_ResponseSyntax - if bucketLocationResponse.LocationConstraint == nil { - location = "us-east-1" - } else { - location = *bucketLocationResponse.LocationConstraint - } - - if location == "EU" { - location = "eu-west-1" - } - - s.cache.Put(cacheKey, location) - return location, nil -} diff --git a/pkg/remote/aws/repository/s3_repository_test.go b/pkg/remote/aws/repository/s3_repository_test.go deleted file mode 100644 index 4e49b6bf..00000000 --- a/pkg/remote/aws/repository/s3_repository_test.go +++ /dev/null @@ -1,866 +0,0 @@ -package repository - -import ( - "fmt" - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/pkg/errors" - "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/remote/aws/client" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/assert" -) - -func Test_s3Repository_ListAllBuckets(t *testing.T) { - - tests := []struct { - name string - mocks func(client *awstest.MockFakeS3) - want []*s3.Bucket - wantErr error - }{ - { - name: "List buckets", - mocks: func(client *awstest.MockFakeS3) { - client.On("ListBuckets", &s3.ListBucketsInput{}).Return( - &s3.ListBucketsOutput{ - Buckets: []*s3.Bucket{ - {Name: aws.String("bucket1")}, - {Name: aws.String("bucket2")}, - {Name: aws.String("bucket3")}, - }, - }, - nil, - ).Once() - }, - want: []*s3.Bucket{ - {Name: aws.String("bucket1")}, - {Name: aws.String("bucket2")}, - {Name: aws.String("bucket3")}, - }, - }, - { - name: "Error listing buckets", - mocks: func(client *awstest.MockFakeS3) { - client.On("ListBuckets", &s3.ListBucketsInput{}).Return( - nil, - awserr.NewRequestFailure(nil, 403, ""), - ).Once() - }, - want: nil, - wantErr: awserr.NewRequestFailure(nil, 403, ""), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", (*aws.Config)(nil)).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.ListAllBuckets() - factory.AssertExpectations(t) - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllBuckets() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*s3.Bucket{}, store.Get("s3ListAllBuckets")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_s3Repository_GetBucketNotification(t *testing.T) { - - tests := []struct { - name string - bucketName, region string - mocks func(client *awstest.MockFakeS3) - want *s3.NotificationConfiguration - wantErr string - }{ - { - name: "get empty bucket notification", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.NotificationConfiguration{}, - nil, - ).Once() - }, - want: nil, - }, - { - name: "get bucket notification with lambda config", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.NotificationConfiguration{ - LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{ - { - Id: aws.String("test"), - }, - }, - }, - nil, - ).Once() - }, - want: &s3.NotificationConfiguration{ - LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{ - { - Id: aws.String("test"), - }, - }, - }, - }, - { - name: "get bucket notification with queue config", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.NotificationConfiguration{ - QueueConfigurations: []*s3.QueueConfiguration{ - { - Id: awssdk.String("test"), - }, - }, - }, - nil, - ).Once() - }, - want: &s3.NotificationConfiguration{ - QueueConfigurations: []*s3.QueueConfiguration{ - { - Id: awssdk.String("test"), - }, - }, - }, - }, - { - name: "get bucket notification with topic config", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.NotificationConfiguration{ - TopicConfigurations: []*s3.TopicConfiguration{ - { - Id: awssdk.String("test"), - }, - }, - }, - nil, - ).Once() - }, - want: &s3.NotificationConfiguration{ - TopicConfigurations: []*s3.TopicConfiguration{ - { - Id: awssdk.String("test"), - }, - }, - }, - }, - { - name: "get bucket location when error", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketNotificationConfiguration", &s3.GetBucketNotificationConfigurationRequest{ - Bucket: aws.String("test-bucket"), - }).Return( - nil, - awserr.New("UnknownError", "aws error", nil), - ).Once() - }, - wantErr: "Error listing bucket notification configuration test-bucket: UnknownError: aws error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", &aws.Config{Region: &tt.region}).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.GetBucketNotification(tt.bucketName, tt.region) - factory.AssertExpectations(t) - if err != nil && tt.wantErr == "" { - t.Fatalf("Unexpected error %+v", err) - } - if err != nil { - assert.Equal(t, tt.wantErr, err.Error()) - } - - if err == nil && tt.want != nil { - // Check that results were cached - cachedData, err := r.GetBucketNotification(tt.bucketName, tt.region) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, &s3.NotificationConfiguration{}, store.Get(fmt.Sprintf("s3GetBucketNotification_%s_%s", tt.bucketName, tt.region))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_s3Repository_GetBucketPolicy(t *testing.T) { - - tests := []struct { - name string - bucketName, region string - mocks func(client *awstest.MockFakeS3) - want *string - wantErr string - }{ - { - name: "get nil bucket policy", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.GetBucketPolicyOutput{}, - nil, - ).Once() - }, - want: nil, - }, - { - name: "get empty bucket policy", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.GetBucketPolicyOutput{ - Policy: awssdk.String(""), - }, - nil, - ).Once() - }, - want: nil, - }, - { - name: "get bucket policy", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ - Bucket: aws.String("test-bucket"), - }).Return( - &s3.GetBucketPolicyOutput{ - Policy: awssdk.String("foobar"), - }, - nil, - ).Once() - }, - want: awssdk.String("foobar"), - }, - { - name: "get bucket location on 404", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ - Bucket: aws.String("test-bucket"), - }).Return( - nil, - awserr.New("NoSuchBucketPolicy", "", nil), - ).Once() - }, - want: nil, - }, - { - name: "get bucket location when error", - bucketName: "test-bucket", - region: "us-east-1", - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketPolicy", &s3.GetBucketPolicyInput{ - Bucket: aws.String("test-bucket"), - }).Return( - nil, - awserr.New("UnknownError", "aws error", nil), - ).Once() - }, - wantErr: "Error listing bucket policy test-bucket: UnknownError: aws error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", &aws.Config{Region: &tt.region}).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.GetBucketPolicy(tt.bucketName, tt.region) - factory.AssertExpectations(t) - if err != nil && tt.wantErr == "" { - t.Fatalf("Unexpected error %+v", err) - } - if err != nil { - assert.Equal(t, tt.wantErr, err.Error()) - } - - if err == nil && tt.want != nil { - // Check that results were cached - cachedData, err := r.GetBucketPolicy(tt.bucketName, tt.region) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, awssdk.String(""), store.Get(fmt.Sprintf("s3GetBucketPolicy_%s_%s", tt.bucketName, tt.region))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_s3Repository_ListBucketInventoryConfigurations(t *testing.T) { - tests := []struct { - name string - input struct { - bucket s3.Bucket - region string - } - mocks func(client *awstest.MockFakeS3) - want []*s3.InventoryConfiguration - wantErr string - }{ - { - name: "List inventory configs", - input: struct { - bucket s3.Bucket - region string - }{ - bucket: s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - region: "us-east-1", - }, - mocks: func(client *awstest.MockFakeS3) { - client.On( - "ListBucketInventoryConfigurations", - &s3.ListBucketInventoryConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - ContinuationToken: nil, - }, - ).Return( - &s3.ListBucketInventoryConfigurationsOutput{ - InventoryConfigurationList: []*s3.InventoryConfiguration{ - {Id: awssdk.String("config1")}, - {Id: awssdk.String("config2")}, - {Id: awssdk.String("config3")}, - }, - IsTruncated: awssdk.Bool(true), - NextContinuationToken: awssdk.String("nexttoken"), - }, - nil, - ).Once() - client.On( - "ListBucketInventoryConfigurations", - &s3.ListBucketInventoryConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - ContinuationToken: awssdk.String("nexttoken"), - }, - ).Return( - &s3.ListBucketInventoryConfigurationsOutput{ - InventoryConfigurationList: []*s3.InventoryConfiguration{ - {Id: awssdk.String("config4")}, - {Id: awssdk.String("config5")}, - {Id: awssdk.String("config6")}, - }, - IsTruncated: awssdk.Bool(false), - }, - nil, - ).Once() - }, - want: []*s3.InventoryConfiguration{ - {Id: awssdk.String("config1")}, - {Id: awssdk.String("config2")}, - {Id: awssdk.String("config3")}, - {Id: awssdk.String("config4")}, - {Id: awssdk.String("config5")}, - {Id: awssdk.String("config6")}, - }, - }, - { - name: "Error listing inventory configs", - input: struct { - bucket s3.Bucket - region string - }{ - bucket: s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - region: "us-east-1", - }, - mocks: func(client *awstest.MockFakeS3) { - client.On( - "ListBucketInventoryConfigurations", - &s3.ListBucketInventoryConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - }, - ).Return( - nil, - errors.New("aws error"), - ).Once() - }, - want: nil, - wantErr: "Error listing bucket inventory configuration test-bucket: aws error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", &aws.Config{Region: awssdk.String(tt.input.region)}).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.ListBucketInventoryConfigurations(&tt.input.bucket, tt.input.region) - factory.AssertExpectations(t) - if err != nil && tt.wantErr == "" { - t.Fatalf("Unexpected error %+v", err) - } - if err != nil { - assert.Equal(t, tt.wantErr, err.Error()) - } - - if err == nil { - // Check that results were cached - cachedData, err := r.ListBucketInventoryConfigurations(&tt.input.bucket, tt.input.region) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*s3.InventoryConfiguration{}, store.Get(fmt.Sprintf("s3ListBucketInventoryConfigurations_%s_%s", *tt.input.bucket.Name, tt.input.region))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_s3Repository_ListBucketMetricsConfigurations(t *testing.T) { - tests := []struct { - name string - input struct { - bucket s3.Bucket - region string - } - mocks func(client *awstest.MockFakeS3) - want []*s3.MetricsConfiguration - wantErr string - }{ - { - name: "List metrics configs", - input: struct { - bucket s3.Bucket - region string - }{ - bucket: s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - region: "us-east-1", - }, - mocks: func(client *awstest.MockFakeS3) { - client.On( - "ListBucketMetricsConfigurations", - &s3.ListBucketMetricsConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - ContinuationToken: nil, - }, - ).Return( - &s3.ListBucketMetricsConfigurationsOutput{ - MetricsConfigurationList: []*s3.MetricsConfiguration{ - {Id: awssdk.String("metric1")}, - {Id: awssdk.String("metric2")}, - {Id: awssdk.String("metric3")}, - }, - IsTruncated: awssdk.Bool(true), - NextContinuationToken: awssdk.String("nexttoken"), - }, - nil, - ).Once() - client.On( - "ListBucketMetricsConfigurations", - &s3.ListBucketMetricsConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - ContinuationToken: awssdk.String("nexttoken"), - }, - ).Return( - &s3.ListBucketMetricsConfigurationsOutput{ - MetricsConfigurationList: []*s3.MetricsConfiguration{ - {Id: awssdk.String("metric4")}, - {Id: awssdk.String("metric5")}, - {Id: awssdk.String("metric6")}, - }, - IsTruncated: awssdk.Bool(false), - }, - nil, - ).Once() - }, - want: []*s3.MetricsConfiguration{ - {Id: awssdk.String("metric1")}, - {Id: awssdk.String("metric2")}, - {Id: awssdk.String("metric3")}, - {Id: awssdk.String("metric4")}, - {Id: awssdk.String("metric5")}, - {Id: awssdk.String("metric6")}, - }, - }, - { - name: "Error listing metrics configs", - input: struct { - bucket s3.Bucket - region string - }{ - bucket: s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - region: "us-east-1", - }, - mocks: func(client *awstest.MockFakeS3) { - client.On( - "ListBucketMetricsConfigurations", - &s3.ListBucketMetricsConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - }, - ).Return( - nil, - errors.New("aws error"), - ).Once() - }, - want: nil, - wantErr: "Error listing bucket metrics configuration test-bucket: aws error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", &aws.Config{Region: awssdk.String(tt.input.region)}).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.ListBucketMetricsConfigurations(&tt.input.bucket, tt.input.region) - factory.AssertExpectations(t) - if err != nil && tt.wantErr == "" { - t.Fatalf("Unexpected error %+v", err) - } - if err != nil { - assert.Equal(t, tt.wantErr, err.Error()) - } - - if err == nil { - // Check that results were cached - cachedData, err := r.ListBucketMetricsConfigurations(&tt.input.bucket, tt.input.region) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*s3.MetricsConfiguration{}, store.Get(fmt.Sprintf("s3ListBucketMetricsConfigurations_%s_%s", *tt.input.bucket.Name, tt.input.region))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_s3Repository_ListBucketAnalyticsConfigurations(t *testing.T) { - tests := []struct { - name string - input struct { - bucket s3.Bucket - region string - } - mocks func(client *awstest.MockFakeS3) - want []*s3.AnalyticsConfiguration - wantErr string - }{ - { - name: "List analytics configs", - input: struct { - bucket s3.Bucket - region string - }{ - bucket: s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - region: "us-east-1", - }, - mocks: func(client *awstest.MockFakeS3) { - client.On( - "ListBucketAnalyticsConfigurations", - &s3.ListBucketAnalyticsConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - ContinuationToken: nil, - }, - ).Return( - &s3.ListBucketAnalyticsConfigurationsOutput{ - AnalyticsConfigurationList: []*s3.AnalyticsConfiguration{ - {Id: awssdk.String("analytic1")}, - {Id: awssdk.String("analytic2")}, - {Id: awssdk.String("analytic3")}, - }, - IsTruncated: awssdk.Bool(true), - NextContinuationToken: awssdk.String("nexttoken"), - }, - nil, - ).Once() - client.On( - "ListBucketAnalyticsConfigurations", - &s3.ListBucketAnalyticsConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - ContinuationToken: awssdk.String("nexttoken"), - }, - ).Return( - &s3.ListBucketAnalyticsConfigurationsOutput{ - AnalyticsConfigurationList: []*s3.AnalyticsConfiguration{ - {Id: awssdk.String("analytic4")}, - {Id: awssdk.String("analytic5")}, - {Id: awssdk.String("analytic6")}, - }, - IsTruncated: awssdk.Bool(false), - }, - nil, - ).Once() - }, - want: []*s3.AnalyticsConfiguration{ - {Id: awssdk.String("analytic1")}, - {Id: awssdk.String("analytic2")}, - {Id: awssdk.String("analytic3")}, - {Id: awssdk.String("analytic4")}, - {Id: awssdk.String("analytic5")}, - {Id: awssdk.String("analytic6")}, - }, - }, - { - name: "Error listing analytics configs", - input: struct { - bucket s3.Bucket - region string - }{ - bucket: s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - region: "us-east-1", - }, - mocks: func(client *awstest.MockFakeS3) { - client.On( - "ListBucketAnalyticsConfigurations", - &s3.ListBucketAnalyticsConfigurationsInput{ - Bucket: awssdk.String("test-bucket"), - }, - ).Return( - nil, - errors.New("aws error"), - ).Once() - }, - want: nil, - wantErr: "Error listing bucket analytics configuration test-bucket: aws error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", &aws.Config{Region: awssdk.String(tt.input.region)}).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.ListBucketAnalyticsConfigurations(&tt.input.bucket, tt.input.region) - factory.AssertExpectations(t) - if err != nil && tt.wantErr == "" { - t.Fatalf("Unexpected error %+v", err) - } - if err != nil { - assert.Equal(t, tt.wantErr, err.Error()) - } - - if err == nil { - // Check that results were cached - cachedData, err := r.ListBucketAnalyticsConfigurations(&tt.input.bucket, tt.input.region) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*s3.AnalyticsConfiguration{}, store.Get(fmt.Sprintf("s3ListBucketAnalyticsConfigurations_%s_%s", *tt.input.bucket.Name, tt.input.region))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_s3Repository_GetBucketLocation(t *testing.T) { - - tests := []struct { - name string - bucket *s3.Bucket - mocks func(client *awstest.MockFakeS3) - want string - wantErr string - }{ - { - name: "get bucket location", - bucket: &s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketLocation", &s3.GetBucketLocationInput{ - Bucket: awssdk.String("test-bucket"), - }).Return( - &s3.GetBucketLocationOutput{ - LocationConstraint: awssdk.String("eu-east-1"), - }, - nil, - ).Once() - }, - want: "eu-east-1", - }, - { - name: "get bucket location for us-east-2", - bucket: &s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketLocation", &s3.GetBucketLocationInput{ - Bucket: awssdk.String("test-bucket"), - }).Return( - &s3.GetBucketLocationOutput{}, - nil, - ).Once() - }, - want: "us-east-1", - }, - { - name: "get bucket location when no such bucket", - bucket: &s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketLocation", &s3.GetBucketLocationInput{ - Bucket: awssdk.String("test-bucket"), - }).Return( - nil, - awserr.New(s3.ErrCodeNoSuchBucket, "", nil), - ).Once() - }, - want: "", - }, - { - name: "get bucket location when error", - bucket: &s3.Bucket{ - Name: awssdk.String("test-bucket"), - }, - mocks: func(client *awstest.MockFakeS3) { - client.On("GetBucketLocation", &s3.GetBucketLocationInput{ - Bucket: awssdk.String("test-bucket"), - }).Return( - nil, - awserr.New("UnknownError", "aws error", nil), - ).Once() - }, - wantErr: "UnknownError: aws error", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - mockedClient := &awstest.MockFakeS3{} - tt.mocks(mockedClient) - factory := client.MockAwsClientFactoryInterface{} - factory.On("GetS3Client", (*aws.Config)(nil)).Return(mockedClient).Once() - r := NewS3Repository(&factory, store) - got, err := r.GetBucketLocation(*tt.bucket.Name) - factory.AssertExpectations(t) - if err != nil && tt.wantErr == "" { - t.Fatalf("Unexpected error %+v", err) - } - if err != nil { - assert.Equal(t, tt.wantErr, err.Error()) - } - - if err == nil && tt.want != "" { - // Check that results were cached - cachedData, err := r.GetBucketLocation(*tt.bucket.Name) - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, "", store.Get(fmt.Sprintf("s3GetBucketLocation_%s", *tt.bucket.Name))) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/sns_repository.go b/pkg/remote/aws/repository/sns_repository.go deleted file mode 100644 index 1e9f3a20..00000000 --- a/pkg/remote/aws/repository/sns_repository.go +++ /dev/null @@ -1,67 +0,0 @@ -package repository - -import ( - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sns" - "github.com/aws/aws-sdk-go/service/sns/snsiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type SNSRepository interface { - ListAllTopics() ([]*sns.Topic, error) - ListAllSubscriptions() ([]*sns.Subscription, error) -} - -type snsRepository struct { - client snsiface.SNSAPI - cache cache.Cache -} - -func NewSNSRepository(session *session.Session, c cache.Cache) *snsRepository { - return &snsRepository{ - sns.New(session), - c, - } -} - -func (r *snsRepository) ListAllTopics() ([]*sns.Topic, error) { - - cacheKey := "snsListAllTopics" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*sns.Topic), nil - } - - var topics []*sns.Topic - input := &sns.ListTopicsInput{} - err := r.client.ListTopicsPages(input, func(res *sns.ListTopicsOutput, lastPage bool) bool { - topics = append(topics, res.Topics...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, topics) - return topics, nil -} - -func (r *snsRepository) ListAllSubscriptions() ([]*sns.Subscription, error) { - if v := r.cache.Get("snsListAllSubscriptions"); v != nil { - return v.([]*sns.Subscription), nil - } - - var subscriptions []*sns.Subscription - input := &sns.ListSubscriptionsInput{} - err := r.client.ListSubscriptionsPages(input, func(res *sns.ListSubscriptionsOutput, lastPage bool) bool { - subscriptions = append(subscriptions, res.Subscriptions...) - return !lastPage - }) - if err != nil { - return nil, err - } - - r.cache.Put("snsListAllSubscriptions", subscriptions) - return subscriptions, nil -} diff --git a/pkg/remote/aws/repository/sns_repository_test.go b/pkg/remote/aws/repository/sns_repository_test.go deleted file mode 100644 index 3554d660..00000000 --- a/pkg/remote/aws/repository/sns_repository_test.go +++ /dev/null @@ -1,161 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - "github.com/stretchr/testify/mock" - - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" - - "github.com/aws/aws-sdk-go/service/sns" -) - -func Test_snsRepository_ListAllTopics(t *testing.T) { - - tests := []struct { - name string - mocks func(client *awstest.MockFakeSNS) - want []*sns.Topic - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeSNS) { - client.On("ListTopicsPages", - &sns.ListTopicsInput{}, - mock.MatchedBy(func(callback func(res *sns.ListTopicsOutput, lastPage bool) bool) bool { - callback(&sns.ListTopicsOutput{ - Topics: []*sns.Topic{ - {TopicArn: aws.String("arn1")}, - {TopicArn: aws.String("arn2")}, - {TopicArn: aws.String("arn3")}, - }, - }, false) - callback(&sns.ListTopicsOutput{ - Topics: []*sns.Topic{ - {TopicArn: aws.String("arn4")}, - {TopicArn: aws.String("arn5")}, - {TopicArn: aws.String("arn6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*sns.Topic{ - {TopicArn: aws.String("arn1")}, - {TopicArn: aws.String("arn2")}, - {TopicArn: aws.String("arn3")}, - {TopicArn: aws.String("arn4")}, - {TopicArn: aws.String("arn5")}, - {TopicArn: aws.String("arn6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeSNS{} - tt.mocks(client) - r := &snsRepository{ - client: client, - cache: store, - } - got, err := r.ListAllTopics() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllTopics() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*sns.Topic{}, store.Get("snsListAllTopics")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_snsRepository_ListAllSubscriptions(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeSNS) - want []*sns.Subscription - wantErr error - }{ - { - name: "List with 2 pages", - mocks: func(client *awstest.MockFakeSNS) { - client.On("ListSubscriptionsPages", - &sns.ListSubscriptionsInput{}, - mock.MatchedBy(func(callback func(res *sns.ListSubscriptionsOutput, lastPage bool) bool) bool { - callback(&sns.ListSubscriptionsOutput{ - Subscriptions: []*sns.Subscription{ - {TopicArn: aws.String("arn1"), SubscriptionArn: aws.String("SubArn1")}, - {TopicArn: aws.String("arn2"), SubscriptionArn: aws.String("SubArn2")}, - {TopicArn: aws.String("arn3"), SubscriptionArn: aws.String("SubArn3")}, - }, - }, false) - callback(&sns.ListSubscriptionsOutput{ - Subscriptions: []*sns.Subscription{ - {TopicArn: aws.String("arn4"), SubscriptionArn: aws.String("SubArn4")}, - {TopicArn: aws.String("arn5"), SubscriptionArn: aws.String("SubArn5")}, - {TopicArn: aws.String("arn6"), SubscriptionArn: aws.String("SubArn6")}, - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*sns.Subscription{ - {TopicArn: aws.String("arn1"), SubscriptionArn: aws.String("SubArn1")}, - {TopicArn: aws.String("arn2"), SubscriptionArn: aws.String("SubArn2")}, - {TopicArn: aws.String("arn3"), SubscriptionArn: aws.String("SubArn3")}, - {TopicArn: aws.String("arn4"), SubscriptionArn: aws.String("SubArn4")}, - {TopicArn: aws.String("arn5"), SubscriptionArn: aws.String("SubArn5")}, - {TopicArn: aws.String("arn6"), SubscriptionArn: aws.String("SubArn6")}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeSNS{} - tt.mocks(client) - r := &snsRepository{ - client: client, - cache: store, - } - got, err := r.ListAllSubscriptions() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllSubscriptions() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*sns.Subscription{}, store.Get("snsListAllSubscriptions")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/repository/sqs_repository.go b/pkg/remote/aws/repository/sqs_repository.go deleted file mode 100644 index 97900ed6..00000000 --- a/pkg/remote/aws/repository/sqs_repository.go +++ /dev/null @@ -1,72 +0,0 @@ -package repository - -import ( - "fmt" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sqs" - "github.com/aws/aws-sdk-go/service/sqs/sqsiface" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type SQSRepository interface { - ListAllQueues() ([]*string, error) - GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) -} - -type sqsRepository struct { - client sqsiface.SQSAPI - cache cache.Cache -} - -func NewSQSRepository(session *session.Session, c cache.Cache) *sqsRepository { - return &sqsRepository{ - sqs.New(session), - c, - } -} - -func (r *sqsRepository) GetQueueAttributes(url string) (*sqs.GetQueueAttributesOutput, error) { - cacheKey := fmt.Sprintf("sqsGetQueueAttributes_%s", url) - if v := r.cache.Get(cacheKey); v != nil { - return v.(*sqs.GetQueueAttributesOutput), nil - } - - attributes, err := r.client.GetQueueAttributes(&sqs.GetQueueAttributesInput{ - AttributeNames: aws.StringSlice([]string{sqs.QueueAttributeNamePolicy}), - QueueUrl: &url, - }) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, attributes) - - return attributes, nil -} - -func (r *sqsRepository) ListAllQueues() ([]*string, error) { - - cacheKey := "sqsListAllQueues" - v := r.cache.GetAndLock(cacheKey) - defer r.cache.Unlock(cacheKey) - if v != nil { - return v.([]*string), nil - } - - var queues []*string - input := sqs.ListQueuesInput{} - err := r.client.ListQueuesPages(&input, - func(resp *sqs.ListQueuesOutput, lastPage bool) bool { - queues = append(queues, resp.QueueUrls...) - return !lastPage - }, - ) - if err != nil { - return nil, err - } - - r.cache.Put(cacheKey, queues) - return queues, nil -} diff --git a/pkg/remote/aws/repository/sqs_repository_test.go b/pkg/remote/aws/repository/sqs_repository_test.go deleted file mode 100644 index f7c11c8d..00000000 --- a/pkg/remote/aws/repository/sqs_repository_test.go +++ /dev/null @@ -1,145 +0,0 @@ -package repository - -import ( - "strings" - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - awstest "github.com/snyk/driftctl/test/aws" - - "github.com/aws/aws-sdk-go/service/sqs" - "github.com/r3labs/diff/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_sqsRepository_ListAllQueues(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeSQS) - want []*string - wantErr error - }{ - { - name: "list with multiple pages", - mocks: func(client *awstest.MockFakeSQS) { - client.On("ListQueuesPages", - &sqs.ListQueuesInput{}, - mock.MatchedBy(func(callback func(res *sqs.ListQueuesOutput, lastPage bool) bool) bool { - callback(&sqs.ListQueuesOutput{ - QueueUrls: []*string{ - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), - }, - }, false) - callback(&sqs.ListQueuesOutput{ - QueueUrls: []*string{ - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), - }, - }, true) - return true - })).Return(nil).Once() - }, - want: []*string{ - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeSQS{} - tt.mocks(client) - r := &sqsRepository{ - client: client, - cache: store, - } - got, err := r.ListAllQueues() - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.ListAllQueues() - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, []*string{}, store.Get("sqsListAllQueues")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} - -func Test_sqsRepository_GetQueueAttributes(t *testing.T) { - tests := []struct { - name string - mocks func(client *awstest.MockFakeSQS) - want *sqs.GetQueueAttributesOutput - wantErr error - }{ - { - name: "get attributes", - mocks: func(client *awstest.MockFakeSQS) { - client.On( - "GetQueueAttributes", - &sqs.GetQueueAttributesInput{ - AttributeNames: awssdk.StringSlice([]string{sqs.QueueAttributeNamePolicy}), - QueueUrl: awssdk.String("http://example.com"), - }, - ).Return( - &sqs.GetQueueAttributesOutput{ - Attributes: map[string]*string{ - sqs.QueueAttributeNamePolicy: awssdk.String("foobar"), - }, - }, - nil, - ).Once() - }, - want: &sqs.GetQueueAttributesOutput{ - Attributes: map[string]*string{ - sqs.QueueAttributeNamePolicy: awssdk.String("foobar"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := cache.New(1) - client := &awstest.MockFakeSQS{} - tt.mocks(client) - r := &sqsRepository{ - client: client, - cache: store, - } - got, err := r.GetQueueAttributes("http://example.com") - assert.Equal(t, tt.wantErr, err) - - if err == nil { - // Check that results were cached - cachedData, err := r.GetQueueAttributes("http://example.com") - assert.NoError(t, err) - assert.Equal(t, got, cachedData) - assert.IsType(t, &sqs.GetQueueAttributesOutput{}, store.Get("sqsGetQueueAttributes_http://example.com")) - } - - changelog, err := diff.Diff(got, tt.want) - assert.Nil(t, err) - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To) - } - t.Fail() - } - }) - } -} diff --git a/pkg/remote/aws/route53_health_check_enumerator.go b/pkg/remote/aws/route53_health_check_enumerator.go deleted file mode 100644 index 16ec37f1..00000000 --- a/pkg/remote/aws/route53_health_check_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type Route53HealthCheckEnumerator struct { - repository repository.Route53Repository - factory resource.ResourceFactory -} - -func NewRoute53HealthCheckEnumerator(repo repository.Route53Repository, factory resource.ResourceFactory) *Route53HealthCheckEnumerator { - return &Route53HealthCheckEnumerator{ - repo, - factory, - } -} - -func (e *Route53HealthCheckEnumerator) SupportedType() resource.ResourceType { - return aws.AwsRoute53HealthCheckResourceType -} - -func (e *Route53HealthCheckEnumerator) Enumerate() ([]*resource.Resource, error) { - healthChecks, err := e.repository.ListAllHealthChecks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(healthChecks)) - - for _, healthCheck := range healthChecks { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *healthCheck.Id, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/route53_record_enumerator.go b/pkg/remote/aws/route53_record_enumerator.go deleted file mode 100644 index 24d2ff1d..00000000 --- a/pkg/remote/aws/route53_record_enumerator.go +++ /dev/null @@ -1,101 +0,0 @@ -package aws - -import ( - "strconv" - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type Route53RecordEnumerator struct { - client repository.Route53Repository - factory resource.ResourceFactory -} - -func NewRoute53RecordEnumerator(repo repository.Route53Repository, factory resource.ResourceFactory) *Route53RecordEnumerator { - return &Route53RecordEnumerator{ - repo, - factory, - } -} - -func (e *Route53RecordEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsRoute53RecordResourceType -} - -func (e *Route53RecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.client.ListAllZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsRoute53ZoneResourceType) - } - - results := make([]*resource.Resource, 0, len(zones)) - - for _, hostedZone := range zones { - records, err := e.listRecordsForZone(strings.TrimPrefix(*hostedZone.Id, "/hostedzone/")) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results = append(results, records...) - } - - return results, err -} - -func (e *Route53RecordEnumerator) listRecordsForZone(zoneId string) ([]*resource.Resource, error) { - - records, err := e.client.ListRecordsForZone(zoneId) - if err != nil { - return nil, err - } - - results := make([]*resource.Resource, 0, len(records)) - - for _, raw := range records { - rawType := *raw.Type - rawName := *raw.Name - rawSetIdentifier := raw.SetIdentifier - - vars := []string{ - zoneId, - strings.ToLower(strings.TrimSuffix(rawName, ".")), - rawType, - } - if rawSetIdentifier != nil { - vars = append(vars, *rawSetIdentifier) - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - e.cleanRecordName(strings.Join(vars, "_")), - map[string]interface{}{ - "type": rawType, - }, - ), - ) - } - - return results, nil -} - -// cleanRecordName -// Route 53 stores certain characters with the octal equivalent in ASCII format. -// This function converts all of these characters back into the original character. -// E.g. "*" is stored as "\\052" and "@" as "\\100" -func (e *Route53RecordEnumerator) cleanRecordName(name string) string { - str := name - s, err := strconv.Unquote(`"` + str + `"`) - if err != nil { - return str - } - return s -} diff --git a/pkg/remote/aws/route53_zone_enumerator.go b/pkg/remote/aws/route53_zone_enumerator.go deleted file mode 100644 index 2ecfde23..00000000 --- a/pkg/remote/aws/route53_zone_enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" -) - -type Route53ZoneSupplier struct { - client repository.Route53Repository - factory resource.ResourceFactory -} - -func NewRoute53ZoneEnumerator(repo repository.Route53Repository, factory resource.ResourceFactory) *Route53ZoneSupplier { - return &Route53ZoneSupplier{ - repo, - factory, - } -} - -func (e *Route53ZoneSupplier) SupportedType() resource.ResourceType { - return resourceaws.AwsRoute53ZoneResourceType -} - -func (e *Route53ZoneSupplier) Enumerate() ([]*resource.Resource, error) { - zones, err := e.client.ListAllZones() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(zones)) - - for _, hostedZone := range zones { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - strings.TrimPrefix(*hostedZone.Id, "/hostedzone/"), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/s3_bucket_analytic_enumerator.go b/pkg/remote/aws/s3_bucket_analytic_enumerator.go deleted file mode 100644 index 5e537085..00000000 --- a/pkg/remote/aws/s3_bucket_analytic_enumerator.go +++ /dev/null @@ -1,80 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketAnalyticEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketAnalyticEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketAnalyticEnumerator { - return &S3BucketAnalyticEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketAnalyticEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketAnalyticsConfigurationResourceType -} - -func (e *S3BucketAnalyticEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket analytic") - continue - } - - analyticsConfigurationList, err := e.repository.ListBucketAnalyticsConfigurations(bucket, region) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, analytics := range analyticsConfigurationList { - id := fmt.Sprintf("%s:%s", *bucket.Name, *analytics.Id) - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{ - "region": region, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/s3_bucket_enumerator.go b/pkg/remote/aws/s3_bucket_enumerator.go deleted file mode 100644 index e245bf13..00000000 --- a/pkg/remote/aws/s3_bucket_enumerator.go +++ /dev/null @@ -1,69 +0,0 @@ -package aws - -import ( - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketEnumerator { - return &S3BucketEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketResourceType -} - -func (e *S3BucketEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket") - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *bucket.Name, - map[string]interface{}{ - "region": region, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/s3_bucket_inventory_enumerator.go b/pkg/remote/aws/s3_bucket_inventory_enumerator.go deleted file mode 100644 index 35f6f83b..00000000 --- a/pkg/remote/aws/s3_bucket_inventory_enumerator.go +++ /dev/null @@ -1,81 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketInventoryEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketInventoryEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketInventoryEnumerator { - return &S3BucketInventoryEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketInventoryEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketInventoryResourceType -} - -func (e *S3BucketInventoryEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket inventory") - continue - } - - inventoryConfigurations, err := e.repository.ListBucketInventoryConfigurations(bucket, region) - if err != nil { - // TODO: we should think about a way to ignore just one bucket inventory listing - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, config := range inventoryConfigurations { - id := fmt.Sprintf("%s:%s", *bucket.Name, *config.Id) - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{ - "region": region, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/s3_bucket_metrics_enumerator.go b/pkg/remote/aws/s3_bucket_metrics_enumerator.go deleted file mode 100644 index 6a23d1d2..00000000 --- a/pkg/remote/aws/s3_bucket_metrics_enumerator.go +++ /dev/null @@ -1,80 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketMetricsEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketMetricsEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketMetricsEnumerator { - return &S3BucketMetricsEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketMetricsEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketMetricResourceType -} - -func (e *S3BucketMetricsEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket") - continue - } - - metricsConfigurationList, err := e.repository.ListBucketMetricsConfigurations(bucket, region) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, metric := range metricsConfigurationList { - id := fmt.Sprintf("%s:%s", *bucket.Name, *metric.Id) - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{ - "region": region, - }, - ), - ) - } - } - - return results, nil -} diff --git a/pkg/remote/aws/s3_bucket_notification_enumerator.go b/pkg/remote/aws/s3_bucket_notification_enumerator.go deleted file mode 100644 index 7e84daec..00000000 --- a/pkg/remote/aws/s3_bucket_notification_enumerator.go +++ /dev/null @@ -1,84 +0,0 @@ -package aws - -import ( - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketNotificationEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketNotificationEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketNotificationEnumerator { - return &S3BucketNotificationEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketNotificationEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketNotificationResourceType -} - -func (e *S3BucketNotificationEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket") - continue - } - - notification, err := e.repository.GetBucketNotification(*bucket.Name, region) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - - if notification == nil { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped empty bucket notification") - continue - } - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *bucket.Name, - map[string]interface{}{ - "region": region, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/s3_bucket_policy_enumerator.go b/pkg/remote/aws/s3_bucket_policy_enumerator.go deleted file mode 100644 index c5548057..00000000 --- a/pkg/remote/aws/s3_bucket_policy_enumerator.go +++ /dev/null @@ -1,78 +0,0 @@ -package aws - -import ( - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketPolicyEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketPolicyEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketPolicyEnumerator { - return &S3BucketPolicyEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketPolicyResourceType -} - -func (e *S3BucketPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket policy") - continue - } - - policy, err := e.repository.GetBucketPolicy(*bucket.Name, region) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - - if policy != nil { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *bucket.Name, - map[string]interface{}{ - "region": region, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/s3_bucket_public_access_block_enumerator.go b/pkg/remote/aws/s3_bucket_public_access_block_enumerator.go deleted file mode 100644 index f61cfc03..00000000 --- a/pkg/remote/aws/s3_bucket_public_access_block_enumerator.go +++ /dev/null @@ -1,82 +0,0 @@ -package aws - -import ( - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type S3BucketPublicAccessBlockEnumerator struct { - repository repository.S3Repository - factory resource.ResourceFactory - providerConfig tf.TerraformProviderConfig - alerter alerter.AlerterInterface -} - -func NewS3BucketPublicAccessBlockEnumerator(repo repository.S3Repository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3BucketPublicAccessBlockEnumerator { - return &S3BucketPublicAccessBlockEnumerator{ - repository: repo, - factory: factory, - providerConfig: providerConfig, - alerter: alerter, - } -} - -func (e *S3BucketPublicAccessBlockEnumerator) SupportedType() resource.ResourceType { - return aws.AwsS3BucketPublicAccessBlockResourceType -} - -func (e *S3BucketPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) { - buckets, err := e.repository.ListAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsS3BucketResourceType) - } - - results := make([]*resource.Resource, 0, len(buckets)) - - for _, bucket := range buckets { - region, err := e.repository.GetBucketLocation(*bucket.Name) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - if region == "" || region != e.providerConfig.DefaultAlias { - logrus.WithFields(logrus.Fields{ - "region": region, - "bucket": *bucket.Name, - }).Debug("Skipped bucket public access block") - continue - } - - block, err := e.repository.GetBucketPublicAccessBlock(*bucket.Name, region) - if err != nil { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, e.alerter, remoteerror.NewResourceScanningError(err, string(e.SupportedType()), *bucket.Name)) - continue - } - - if block != nil { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *bucket.Name, - map[string]interface{}{ - "block_public_acls": awssdk.BoolValue(block.BlockPublicAcls), - "block_public_policy": awssdk.BoolValue(block.BlockPublicPolicy), - "ignore_public_acls": awssdk.BoolValue(block.IgnorePublicAcls), - "restrict_public_buckets": awssdk.BoolValue(block.RestrictPublicBuckets), - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/aws/sns_topic_enumerator.go b/pkg/remote/aws/sns_topic_enumerator.go deleted file mode 100644 index 89e5ecfb..00000000 --- a/pkg/remote/aws/sns_topic_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type SNSTopicEnumerator struct { - repository repository.SNSRepository - factory resource.ResourceFactory -} - -func NewSNSTopicEnumerator(repo repository.SNSRepository, factory resource.ResourceFactory) *SNSTopicEnumerator { - return &SNSTopicEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *SNSTopicEnumerator) SupportedType() resource.ResourceType { - return aws.AwsSnsTopicResourceType -} - -func (e *SNSTopicEnumerator) Enumerate() ([]*resource.Resource, error) { - topics, err := e.repository.ListAllTopics() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(topics)) - - for _, topic := range topics { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *topic.TopicArn, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/sns_topic_policy_enumerator.go b/pkg/remote/aws/sns_topic_policy_enumerator.go deleted file mode 100644 index c5e7b2ad..00000000 --- a/pkg/remote/aws/sns_topic_policy_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type SNSTopicPolicyEnumerator struct { - repository repository.SNSRepository - factory resource.ResourceFactory -} - -func NewSNSTopicPolicyEnumerator(repo repository.SNSRepository, factory resource.ResourceFactory) *SNSTopicPolicyEnumerator { - return &SNSTopicPolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *SNSTopicPolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsSnsTopicPolicyResourceType -} - -func (e *SNSTopicPolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - topics, err := e.repository.ListAllTopics() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsSnsTopicResourceType) - } - - results := make([]*resource.Resource, 0, len(topics)) - - for _, topic := range topics { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *topic.TopicArn, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/sns_topic_subscription_enumerator.go b/pkg/remote/aws/sns_topic_subscription_enumerator.go deleted file mode 100644 index ce7fe078..00000000 --- a/pkg/remote/aws/sns_topic_subscription_enumerator.go +++ /dev/null @@ -1,85 +0,0 @@ -package aws - -import ( - "fmt" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" -) - -type wrongArnTopicAlert struct { - arn string - endpoint *string -} - -func NewWrongArnTopicAlert(arn string, endpoint *string) *wrongArnTopicAlert { - return &wrongArnTopicAlert{arn: arn, endpoint: endpoint} -} - -func (p *wrongArnTopicAlert) Message() string { - return fmt.Sprintf("%s with incorrect subscription arn (%s) for endpoint \"%s\" will be ignored", - aws.AwsSnsTopicSubscriptionResourceType, - p.arn, - awssdk.StringValue(p.endpoint)) -} - -func (p *wrongArnTopicAlert) ShouldIgnoreResource() bool { - return false -} - -type SNSTopicSubscriptionEnumerator struct { - repository repository.SNSRepository - factory resource.ResourceFactory - alerter alerter.AlerterInterface -} - -func NewSNSTopicSubscriptionEnumerator( - repo repository.SNSRepository, - factory resource.ResourceFactory, - alerter alerter.AlerterInterface, -) *SNSTopicSubscriptionEnumerator { - return &SNSTopicSubscriptionEnumerator{ - repo, - factory, - alerter, - } -} - -func (e *SNSTopicSubscriptionEnumerator) SupportedType() resource.ResourceType { - return aws.AwsSnsTopicSubscriptionResourceType -} - -func (e *SNSTopicSubscriptionEnumerator) Enumerate() ([]*resource.Resource, error) { - allSubscriptions, err := e.repository.ListAllSubscriptions() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(allSubscriptions)) - - for _, subscription := range allSubscriptions { - if subscription.SubscriptionArn == nil || !arn.IsARN(*subscription.SubscriptionArn) { - e.alerter.SendAlert( - fmt.Sprintf("%s.%s", e.SupportedType(), *subscription.SubscriptionArn), - NewWrongArnTopicAlert(*subscription.SubscriptionArn, subscription.Endpoint), - ) - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *subscription.SubscriptionArn, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/sqs_queue_details_fetcher.go b/pkg/remote/aws/sqs_queue_details_fetcher.go deleted file mode 100644 index ed267409..00000000 --- a/pkg/remote/aws/sqs_queue_details_fetcher.go +++ /dev/null @@ -1,47 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" -) - -type SQSQueueDetailsFetcher struct { - reader terraform.ResourceReader - deserializer *resource.Deserializer -} - -func NewSQSQueueDetailsFetcher(provider terraform.ResourceReader, deserializer *resource.Deserializer) *SQSQueueDetailsFetcher { - return &SQSQueueDetailsFetcher{ - reader: provider, - deserializer: deserializer, - } -} - -func (r *SQSQueueDetailsFetcher) ReadDetails(res *resource.Resource) (*resource.Resource, error) { - ctyVal, err := r.reader.ReadResource(terraform.ReadResourceArgs{ - ID: res.ResourceId(), - Ty: aws.AwsSqsQueueResourceType, - }) - if err != nil { - if strings.Contains(err.Error(), "NonExistentQueue") { - logrus.WithFields(logrus.Fields{ - "id": res.ResourceId(), - "type": aws.AwsSqsQueueResourceType, - }).Debugf("Ignoring queue that seems to be already deleted: %+v", err) - return nil, nil - } - logrus.Error(err) - return nil, remoteerror.NewResourceScanningError(err, res.ResourceType(), res.ResourceId()) - } - deserializedRes, err := r.deserializer.DeserializeOne(aws.AwsSqsQueueResourceType, *ctyVal) - if err != nil { - return nil, err - } - - return deserializedRes, nil -} diff --git a/pkg/remote/aws/sqs_queue_enumerator.go b/pkg/remote/aws/sqs_queue_enumerator.go deleted file mode 100644 index de2a70ae..00000000 --- a/pkg/remote/aws/sqs_queue_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - - awssdk "github.com/aws/aws-sdk-go/aws" -) - -type SQSQueueEnumerator struct { - repository repository.SQSRepository - factory resource.ResourceFactory -} - -func NewSQSQueueEnumerator(repo repository.SQSRepository, factory resource.ResourceFactory) *SQSQueueEnumerator { - return &SQSQueueEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *SQSQueueEnumerator) SupportedType() resource.ResourceType { - return aws.AwsSqsQueueResourceType -} - -func (e *SQSQueueEnumerator) Enumerate() ([]*resource.Resource, error) { - queues, err := e.repository.ListAllQueues() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(queues)) - - for _, queue := range queues { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - awssdk.StringValue(queue), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/sqs_queue_policy_enumerator.go b/pkg/remote/aws/sqs_queue_policy_enumerator.go deleted file mode 100644 index b60828c7..00000000 --- a/pkg/remote/aws/sqs_queue_policy_enumerator.go +++ /dev/null @@ -1,69 +0,0 @@ -package aws - -import ( - "strings" - - "github.com/aws/aws-sdk-go/service/sqs" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/aws" - - awssdk "github.com/aws/aws-sdk-go/aws" -) - -type SQSQueuePolicyEnumerator struct { - repository repository.SQSRepository - factory resource.ResourceFactory -} - -func NewSQSQueuePolicyEnumerator(repo repository.SQSRepository, factory resource.ResourceFactory) *SQSQueuePolicyEnumerator { - return &SQSQueuePolicyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *SQSQueuePolicyEnumerator) SupportedType() resource.ResourceType { - return aws.AwsSqsQueuePolicyResourceType -} - -func (e *SQSQueuePolicyEnumerator) Enumerate() ([]*resource.Resource, error) { - queues, err := e.repository.ListAllQueues() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsSqsQueueResourceType) - } - - results := make([]*resource.Resource, 0, len(queues)) - - for _, queue := range queues { - attrs := map[string]interface{}{ - "policy": "", - } - attributes, err := e.repository.GetQueueAttributes(*queue) - if err != nil { - if strings.Contains(err.Error(), "NonExistentQueue") { - logrus.WithFields(logrus.Fields{ - "queue": *queue, - "type": aws.AwsSqsQueueResourceType, - }).Debugf("Ignoring queue that seems to be already deleted: %+v", err) - continue - } - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - if attributes.Attributes != nil { - attrs["policy"] = *attributes.Attributes[sqs.QueueAttributeNamePolicy] - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - awssdk.StringValue(queue), - attrs, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/aws/vpc_default_security_group_enumerator.go b/pkg/remote/aws/vpc_default_security_group_enumerator.go deleted file mode 100644 index 5e36f9f6..00000000 --- a/pkg/remote/aws/vpc_default_security_group_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/aws/aws-sdk-go/aws" -) - -type VPCDefaultSecurityGroupEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewVPCDefaultSecurityGroupEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *VPCDefaultSecurityGroupEnumerator { - return &VPCDefaultSecurityGroupEnumerator{ - repo, - factory, - } -} - -func (e *VPCDefaultSecurityGroupEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsDefaultSecurityGroupResourceType -} - -func (e *VPCDefaultSecurityGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - _, defaultSecurityGroups, err := e.repository.ListAllSecurityGroups() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(defaultSecurityGroups)) - - for _, item := range defaultSecurityGroups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - aws.StringValue(item.GroupId), - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/vpc_enumerator.go b/pkg/remote/aws/vpc_enumerator.go deleted file mode 100644 index 59051808..00000000 --- a/pkg/remote/aws/vpc_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/resource" -) - -type VPCEnumerator struct { - repo repository.EC2Repository - factory resource.ResourceFactory -} - -func NewVPCEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *VPCEnumerator { - return &VPCEnumerator{ - repo, - factory, - } -} - -func (e *VPCEnumerator) SupportedType() resource.ResourceType { - return aws.AwsVpcResourceType -} - -func (e *VPCEnumerator) Enumerate() ([]*resource.Resource, error) { - VPCs, _, err := e.repo.ListAllVPCs() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(VPCs)) - - for _, item := range VPCs { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *item.VpcId, - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/vpc_security_group_enumerator.go b/pkg/remote/aws/vpc_security_group_enumerator.go deleted file mode 100644 index be93e369..00000000 --- a/pkg/remote/aws/vpc_security_group_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/aws/aws-sdk-go/aws" -) - -type VPCSecurityGroupEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -func NewVPCSecurityGroupEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *VPCSecurityGroupEnumerator { - return &VPCSecurityGroupEnumerator{ - repo, - factory, - } -} - -func (e *VPCSecurityGroupEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsSecurityGroupResourceType -} - -func (e *VPCSecurityGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - securityGroups, _, err := e.repository.ListAllSecurityGroups() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(securityGroups)) - - for _, item := range securityGroups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - aws.StringValue(item.GroupId), - map[string]interface{}{}, - ), - ) - } - - return results, nil -} diff --git a/pkg/remote/aws/vpc_security_group_rule_enumerator.go b/pkg/remote/aws/vpc_security_group_rule_enumerator.go deleted file mode 100644 index ebd80567..00000000 --- a/pkg/remote/aws/vpc_security_group_rule_enumerator.go +++ /dev/null @@ -1,170 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/remote/aws/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" -) - -const ( - sgRuleTypeIngress = "ingress" - sgRuleTypeEgress = "egress" -) - -type VPCSecurityGroupRuleEnumerator struct { - repository repository.EC2Repository - factory resource.ResourceFactory -} - -type securityGroupRule struct { - Type string - SecurityGroupId string - Protocol string - FromPort float64 - ToPort float64 - Self bool - SourceSecurityGroupId string - CidrBlocks []string - Ipv6CidrBlocks []string - PrefixListIds []string -} - -func (s *securityGroupRule) getId() string { - attrs := s.getAttrs() - return resourceaws.CreateSecurityGroupRuleIdHash(&attrs) -} - -func (s *securityGroupRule) getAttrs() resource.Attributes { - attrs := resource.Attributes{ - "type": s.Type, - "security_group_id": s.SecurityGroupId, - "protocol": s.Protocol, - "from_port": s.FromPort, - "to_port": s.ToPort, - "self": s.Self, - "source_security_group_id": s.SourceSecurityGroupId, - "cidr_blocks": toInterfaceSlice(s.CidrBlocks), - "ipv6_cidr_blocks": toInterfaceSlice(s.Ipv6CidrBlocks), - "prefix_list_ids": toInterfaceSlice(s.PrefixListIds), - } - - return attrs -} - -func toInterfaceSlice(val []string) []interface{} { - var res []interface{} - for _, v := range val { - res = append(res, v) - } - return res -} - -func NewVPCSecurityGroupRuleEnumerator(repository repository.EC2Repository, factory resource.ResourceFactory) *VPCSecurityGroupRuleEnumerator { - return &VPCSecurityGroupRuleEnumerator{ - repository, - factory, - } -} - -func (e *VPCSecurityGroupRuleEnumerator) SupportedType() resource.ResourceType { - return resourceaws.AwsSecurityGroupRuleResourceType -} - -func (e *VPCSecurityGroupRuleEnumerator) Enumerate() ([]*resource.Resource, error) { - securityGroups, defaultSecurityGroups, err := e.repository.ListAllSecurityGroups() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsSecurityGroupResourceType) - } - - secGroups := make([]*ec2.SecurityGroup, 0, len(securityGroups)+len(defaultSecurityGroups)) - secGroups = append(secGroups, securityGroups...) - secGroups = append(secGroups, defaultSecurityGroups...) - securityGroupsRules := e.listSecurityGroupsRules(secGroups) - - results := make([]*resource.Resource, 0, len(securityGroupsRules)) - for _, rule := range securityGroupsRules { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - rule.getId(), - rule.getAttrs(), - ), - ) - } - - return results, nil -} - -func (e *VPCSecurityGroupRuleEnumerator) listSecurityGroupsRules(securityGroups []*ec2.SecurityGroup) []securityGroupRule { - var securityGroupsRules []securityGroupRule - for _, sg := range securityGroups { - for _, rule := range sg.IpPermissions { - securityGroupsRules = append(securityGroupsRules, e.addSecurityGroupRule(sgRuleTypeIngress, rule, sg)...) - } - for _, rule := range sg.IpPermissionsEgress { - securityGroupsRules = append(securityGroupsRules, e.addSecurityGroupRule(sgRuleTypeEgress, rule, sg)...) - } - } - return securityGroupsRules -} - -// addSecurityGroupRule will iterate through each "Source" as per Aws definition and create a -// rule with custom attributes -func (e *VPCSecurityGroupRuleEnumerator) addSecurityGroupRule(ruleType string, rule *ec2.IpPermission, sg *ec2.SecurityGroup) []securityGroupRule { - var rules []securityGroupRule - for _, groupPair := range rule.UserIdGroupPairs { - r := securityGroupRule{ - Type: ruleType, - SecurityGroupId: aws.StringValue(sg.GroupId), - Protocol: aws.StringValue(rule.IpProtocol), - FromPort: float64(aws.Int64Value(rule.FromPort)), - ToPort: float64(aws.Int64Value(rule.ToPort)), - } - if aws.StringValue(groupPair.GroupId) == aws.StringValue(sg.GroupId) { - r.Self = true - } else { - r.SourceSecurityGroupId = aws.StringValue(groupPair.GroupId) - } - rules = append(rules, r) - } - for _, ipRange := range rule.IpRanges { - r := securityGroupRule{ - Type: ruleType, - SecurityGroupId: aws.StringValue(sg.GroupId), - Protocol: aws.StringValue(rule.IpProtocol), - FromPort: float64(aws.Int64Value(rule.FromPort)), - ToPort: float64(aws.Int64Value(rule.ToPort)), - CidrBlocks: []string{aws.StringValue(ipRange.CidrIp)}, - } - rules = append(rules, r) - } - for _, ipRange := range rule.Ipv6Ranges { - r := securityGroupRule{ - Type: ruleType, - SecurityGroupId: aws.StringValue(sg.GroupId), - Protocol: aws.StringValue(rule.IpProtocol), - FromPort: float64(aws.Int64Value(rule.FromPort)), - ToPort: float64(aws.Int64Value(rule.ToPort)), - Ipv6CidrBlocks: []string{aws.StringValue(ipRange.CidrIpv6)}, - } - rules = append(rules, r) - } - for _, listId := range rule.PrefixListIds { - r := securityGroupRule{ - Type: ruleType, - SecurityGroupId: aws.StringValue(sg.GroupId), - Protocol: aws.StringValue(rule.IpProtocol), - FromPort: float64(aws.Int64Value(rule.FromPort)), - ToPort: float64(aws.Int64Value(rule.ToPort)), - PrefixListIds: []string{aws.StringValue(listId.PrefixListId)}, - } - rules = append(rules, r) - } - return rules -} diff --git a/pkg/remote/aws_api_gateway_scanner_test.go b/pkg/remote/aws_api_gateway_scanner_test.go deleted file mode 100644 index 9efef0a0..00000000 --- a/pkg/remote/aws_api_gateway_scanner_test.go +++ /dev/null @@ -1,1726 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test/remote" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestApiGatewayRestApi(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway rest apis", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRestApis").Return([]*apigateway.RestApi{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway rest apis", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRestApis").Return([]*apigateway.RestApi{ - {Id: awssdk.String("3of73v5ob4")}, - {Id: awssdk.String("1jitcobwol")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "3of73v5ob4") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayRestApiResourceType) - - assert.Equal(t, got[1].ResourceId(), "1jitcobwol") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayRestApiResourceType) - }, - }, - { - test: "cannot list api gateway rest apis", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayRestApiResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRestApiResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayRestApiResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayRestApiEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayAccount(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway account", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("GetAccount").Return(nil, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "empty api gateway account", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("GetAccount").Return(&apigateway.Account{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "api-gateway-account") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayAccountResourceType) - }, - }, - { - test: "cannot get api gateway account", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("GetAccount").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayAccountResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAccountResourceType, resourceaws.AwsApiGatewayAccountResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayAccountResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayAccountEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayApiKey(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway api keys", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApiKeys").Return([]*apigateway.ApiKey{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway api keys", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApiKeys").Return([]*apigateway.ApiKey{ - {Id: awssdk.String("fuwnl8lrva")}, - {Id: awssdk.String("9ge737dd45")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "fuwnl8lrva") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayApiKeyResourceType) - - assert.Equal(t, got[1].ResourceId(), "9ge737dd45") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayApiKeyResourceType) - }, - }, - { - test: "cannot list api gateway api keys", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApiKeys").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayApiKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayApiKeyResourceType, resourceaws.AwsApiGatewayApiKeyResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayApiKeyResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayApiKeyEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayAuthorizer(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("3of73v5ob4")}, - {Id: awssdk.String("1jitcobwol")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway authorizers", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return([]*apigateway.Authorizer{}, nil).Once() - repo.On("ListAllRestApiAuthorizers", *apis[1].Id).Return([]*apigateway.Authorizer{}, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway authorizers", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return([]*apigateway.Authorizer{ - {Id: awssdk.String("ypcpde")}, - }, nil).Once() - repo.On("ListAllRestApiAuthorizers", *apis[1].Id).Return([]*apigateway.Authorizer{ - {Id: awssdk.String("bwhebj")}, - }, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "ypcpde") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayAuthorizerResourceType) - - assert.Equal(t, got[1].ResourceId(), "bwhebj") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayAuthorizerResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayAuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayAuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayAuthorizerResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayAuthorizerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayStage(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("3of73v5ob4")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway stages", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway stages", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{ - {StageName: awssdk.String("foo")}, - {StageName: awssdk.String("baz")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "ags-3of73v5ob4-foo") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayStageResourceType) - - assert.Equal(t, got[1].ResourceId(), "ags-3of73v5ob4-baz") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayStageResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayStageResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayStageResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayStageResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway stages", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiStages", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayStageResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayStageResourceType, resourceaws.AwsApiGatewayStageResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayStageResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayStageEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayResource(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("3of73v5ob4")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("21zk4y"), Path: awssdk.String("/")}, - {Id: awssdk.String("2ltv32p058"), Path: awssdk.String("/")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "21zk4y") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayResourceResourceType) - - assert.Equal(t, got[1].ResourceId(), "2ltv32p058") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayResourceResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayResourceResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayResourceResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayResourceResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayResourceResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayResourceResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayResourceResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayResourceEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayDomainName(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway domain names", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway domain name", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{ - {DomainName: awssdk.String("example-driftctl.com")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "example-driftctl.com") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayDomainNameResourceType) - }, - }, - { - test: "cannot list api gateway domain names", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDomainNames").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayDomainNameResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayDomainNameResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayDomainNameResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayDomainNameEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayVpcLink(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway vpc links", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVpcLinks").Return([]*apigateway.UpdateVpcLinkOutput{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway vpc link", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVpcLinks").Return([]*apigateway.UpdateVpcLinkOutput{ - {Id: awssdk.String("ipu24n")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "ipu24n") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayVpcLinkResourceType) - }, - }, - { - test: "cannot list api gateway vpc links", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVpcLinks").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayVpcLinkResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayVpcLinkResourceType, resourceaws.AwsApiGatewayVpcLinkResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayVpcLinkResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayVpcLinkEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayRequestValidator(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("vryjzimtj1")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway request validators", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiRequestValidators", *apis[0].Id).Return([]*apigateway.UpdateRequestValidatorOutput{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway request validators", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiRequestValidators", *apis[0].Id).Return([]*apigateway.UpdateRequestValidatorOutput{ - {Id: awssdk.String("ywlcuf")}, - {Id: awssdk.String("qmpbs8")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "ywlcuf") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayRequestValidatorResourceType) - - assert.Equal(t, got[1].ResourceId(), "qmpbs8") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayRequestValidatorResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayRequestValidatorResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway request validators", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiRequestValidators", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayRequestValidatorResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType, resourceaws.AwsApiGatewayRequestValidatorResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayRequestValidatorResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayRequestValidatorEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayRestApiPolicy(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway rest api policies", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRestApis").Return([]*apigateway.RestApi{ - {Id: awssdk.String("3of73v5ob4")}, - {Id: awssdk.String("9x7kq9pbyh"), Policy: awssdk.String("")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway rest api policies", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRestApis").Return([]*apigateway.RestApi{ - {Id: awssdk.String("c3n3aqga5d"), Policy: awssdk.String("{\"Version\":\"2012-10-17\",\"Statement\":[{\"Effect\":\"Allow\",\"Principal\":{\"AWS\":\"*\"},\"Action\":\"execute-api:Invoke\",\"Resource\":\"arn:aws:execute-api:us-east-1:111111111111:c3n3aqga5d/*\",\"Condition\":{\"IpAddress\":{\"aws:SourceIp\":\"123.123.123.123/32\"}}}]}")}, - {Id: awssdk.String("9y1eus3hr7"), Policy: awssdk.String("{\"Version\":\"2012-10-17\",\"Statement\":[{\"Effect\":\"Allow\",\"Principal\":{\"AWS\":\"*\"},\"Action\":\"execute-api:Invoke\",\"Resource\":\"arn:aws:execute-api:us-east-1:111111111111:9y1eus3hr7/*\",\"Condition\":{\"IpAddress\":{\"aws:SourceIp\":\"123.123.123.123/32\"}}}]}")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "c3n3aqga5d") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayRestApiPolicyResourceType) - - assert.Equal(t, got[1].ResourceId(), "9y1eus3hr7") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayRestApiPolicyResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayRestApiPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRestApiPolicyResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayRestApiPolicyResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayRestApiPolicyEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayBasePathMapping(t *testing.T) { - dummyError := errors.New("this is an error") - domainNames := []*apigateway.DomainName{ - {DomainName: awssdk.String("example-driftctl.com")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no domain name base path mappings", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllDomainNames").Return(domainNames, nil) - repo.On("ListAllDomainNameBasePathMappings", *domainNames[0].DomainName).Return([]*apigateway.BasePathMapping{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple domain name base path mappings", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllDomainNames").Return(domainNames, nil) - repo.On("ListAllDomainNameBasePathMappings", *domainNames[0].DomainName).Return([]*apigateway.BasePathMapping{ - {BasePath: awssdk.String("foo")}, - {BasePath: awssdk.String("(none)")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "example-driftctl.com/foo") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayBasePathMappingResourceType) - - assert.Equal(t, got[1].ResourceId(), "example-driftctl.com/") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayBasePathMappingResourceType) - }, - }, - { - test: "cannot list domain names", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllDomainNames").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayBasePathMappingResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), - }, - { - test: "cannot list domain name base path mappings", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllDomainNames").Return(domainNames, nil) - repo.On("ListAllDomainNameBasePathMappings", *domainNames[0].DomainName).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayBasePathMappingResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType, resourceaws.AwsApiGatewayBasePathMappingResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayBasePathMappingResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayBasePathMappingEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayMethod(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("vryjzimtj1")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway methods", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway methods", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ - "GET": {}, - "POST": {}, - "DELETE": {}, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 3) - - assert.Equal(t, got[0].ResourceId(), "agm-vryjzimtj1-hl7ksq-DELETE") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayMethodResourceType) - - assert.Equal(t, got[1].ResourceId(), "agm-vryjzimtj1-hl7ksq-GET") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayMethodResourceType) - - assert.Equal(t, got[2].ResourceId(), "agm-vryjzimtj1-hl7ksq-POST") - assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayMethodResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResourceType, resourceaws.AwsApiGatewayResourceResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayMethodEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayModel(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("vryjzimtj1")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway models", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiModels", *apis[0].Id).Return([]*apigateway.Model{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway models", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiModels", *apis[0].Id).Return([]*apigateway.Model{ - {Id: awssdk.String("g68a4s")}, - {Id: awssdk.String("85v536")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "g68a4s") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayModelResourceType) - - assert.Equal(t, got[1].ResourceId(), "85v536") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayModelResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayModelResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayModelResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayModelResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway models", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiModels", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayModelResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayModelResourceType, resourceaws.AwsApiGatewayModelResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayModelResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayModelEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayMethodResponse(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("vryjzimtj1")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway method responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ - "GET": {}, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway method responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("hl7ksq"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ - "GET": {MethodResponses: map[string]*apigateway.MethodResponse{ - "200": {}, - "404": {}, - "503": {}, - }}, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 3) - - assert.Equal(t, got[0].ResourceId(), "agmr-vryjzimtj1-hl7ksq-GET-200") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayMethodResponseResourceType) - - assert.Equal(t, got[1].ResourceId(), "agmr-vryjzimtj1-hl7ksq-GET-404") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayMethodResponseResourceType) - - assert.Equal(t, got[2].ResourceId(), "agmr-vryjzimtj1-hl7ksq-GET-503") - assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayMethodResponseResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayMethodResponseEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayGatewayResponse(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("vryjzimtj1")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway gateway responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiGatewayResponses", *apis[0].Id).Return([]*apigateway.UpdateGatewayResponseOutput{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway gateway responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiGatewayResponses", *apis[0].Id).Return([]*apigateway.UpdateGatewayResponseOutput{ - {ResponseType: awssdk.String("UNAUTHORIZED")}, - {ResponseType: awssdk.String("ACCESS_DENIED")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "aggr-vryjzimtj1-UNAUTHORIZED") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayGatewayResponseResourceType) - - assert.Equal(t, got[1].ResourceId(), "aggr-vryjzimtj1-ACCESS_DENIED") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayGatewayResponseResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayGatewayResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway gateway responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiGatewayResponses", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayGatewayResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType, resourceaws.AwsApiGatewayGatewayResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayGatewayResponseResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayGatewayResponseEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayMethodSettings(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("vryjzimtj1")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway method settings", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{ - {StageName: awssdk.String("foo"), MethodSettings: map[string]*apigateway.MethodSetting{}}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway method settings", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiStages", *apis[0].Id).Return([]*apigateway.Stage{ - {StageName: awssdk.String("foo"), MethodSettings: map[string]*apigateway.MethodSetting{ - "*/*": {}, - "foo/GET": {}, - "foo/DELETE": {}, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 3) - - assert.Equal(t, got[0].ResourceId(), "vryjzimtj1-foo-*/*") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayMethodSettingsResourceType) - - assert.Equal(t, got[1].ResourceId(), "vryjzimtj1-foo-foo/DELETE") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayMethodSettingsResourceType) - - assert.Equal(t, got[2].ResourceId(), "vryjzimtj1-foo-foo/GET") - assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayMethodSettingsResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodSettingsResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway settings", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiStages", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayMethodSettingsResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayStageResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayMethodSettingsResourceType, resourceaws.AwsApiGatewayStageResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayMethodSettingsEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayIntegration(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("u7jce3lokk")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway integrations", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway integrations", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ - "GET": {}, - "POST": {}, - "DELETE": {}, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 3) - - assert.Equal(t, got[0].ResourceId(), "agi-u7jce3lokk-z9ag20-DELETE") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayIntegrationResourceType) - - assert.Equal(t, got[1].ResourceId(), "agi-u7jce3lokk-z9ag20-GET") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayIntegrationResourceType) - - assert.Equal(t, got[2].ResourceId(), "agi-u7jce3lokk-z9ag20-POST") - assert.Equal(t, got[2].ResourceType(), resourceaws.AwsApiGatewayIntegrationResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResourceType, resourceaws.AwsApiGatewayResourceResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayIntegrationEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayIntegrationResponse(t *testing.T) { - dummyError := errors.New("this is an error") - apis := []*apigateway.RestApi{ - {Id: awssdk.String("u7jce3lokk")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway integration responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ - "GET": {}, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway integration responses", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return([]*apigateway.Resource{ - {Id: awssdk.String("z9ag20"), Path: awssdk.String("/foo"), ResourceMethods: map[string]*apigateway.Method{ - "GET": { - MethodIntegration: &apigateway.Integration{ - IntegrationResponses: map[string]*apigateway.IntegrationResponse{ - "200": {}, - "302": {}, - }, - }, - }, - }}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "agir-u7jce3lokk-z9ag20-GET-200") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayIntegrationResponseResourceType) - - assert.Equal(t, got[1].ResourceId(), "agir-u7jce3lokk-z9ag20-GET-302") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayIntegrationResponseResourceType) - }, - }, - { - test: "cannot list rest apis", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayRestApiResourceType), - }, - { - test: "cannot list api gateway resources", - mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRestApis").Return(apis, nil) - repo.On("ListAllRestApiResources", *apis[0].Id).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayIntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayIntegrationResponseResourceType, resourceaws.AwsApiGatewayResourceResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayIntegrationResponseEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := remote.NewSortableScanner(NewScanner(remoteLibrary, alerter, scanOptions, testFilter)) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_apigatewayv2_scanner_test.go b/pkg/remote/aws_apigatewayv2_scanner_test.go deleted file mode 100644 index 28d3f110..00000000 --- a/pkg/remote/aws_apigatewayv2_scanner_test.go +++ /dev/null @@ -1,1229 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/apigateway" - "github.com/aws/aws-sdk-go/service/apigatewayv2" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestApiGatewayV2Api(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 api", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 api", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("f5vdrg12tk")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "f5vdrg12tk") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2ApiResourceType) - }, - }, - { - test: "cannot list api gateway v2 apis", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2ApiResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2ApiEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Route(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 api", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 api with a single route", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("a-gateway")}, - }, nil) - repository.On("ListAllApiRoutes", awssdk.String("a-gateway")). - Return([]*apigatewayv2.Route{{ - RouteId: awssdk.String("a-route"), - RouteKey: awssdk.String("POST /an-example"), - }}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, "a-route", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsApiGatewayV2RouteResourceType, got[0].ResourceType()) - expectedAttrs := &resource.Attributes{ - "api_id": "a-gateway", - "route_key": "POST /an-example", - } - assert.Equal(t, expectedAttrs, got[0].Attributes()) - }, - }, - { - test: "cannot list api gateway v2 apis", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2RouteResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2RouteEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Deployment(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "single api gateway v2 api with a single deployment", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("a-gateway")}, - }, nil) - repository.On("ListAllApiDeployments", awssdk.String("a-gateway")). - Return([]*apigatewayv2.Deployment{{ - DeploymentId: awssdk.String("a-deployment"), - }}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, "a-deployment", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsApiGatewayV2DeploymentResourceType, got[0].ResourceType()) - expectedAttrs := &resource.Attributes{} - assert.Equal(t, expectedAttrs, got[0].Attributes()) - }, - }, - { - test: "no API gateways", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single API gateway with no deployments", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("a-gateway")}, - }, nil) - repository.On("ListAllApiDeployments", awssdk.String("a-gateway")). - Return([]*apigatewayv2.Deployment{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing API gateways", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2DeploymentResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2RouteResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DeploymentResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - { - test: "error listing deployments of an API", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("a-gateway")}, - }, nil) - repository.On("ListAllApiDeployments", awssdk.String("a-gateway")).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2DeploymentResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DeploymentResourceType, resourceaws.AwsApiGatewayV2DeploymentResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DeploymentResourceType, resourceaws.AwsApiGatewayV2DeploymentResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2DeploymentEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2VpcLink(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 vpc links", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVpcLinks").Return([]*apigatewayv2.VpcLink{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 vpc link", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVpcLinks").Return([]*apigatewayv2.VpcLink{ - {VpcLinkId: awssdk.String("b8r351")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "b8r351") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2VpcLinkResourceType) - }, - }, - { - test: "cannot list api gateway v2 vpc links", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVpcLinks").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2VpcLinkResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2VpcLinkResourceType, resourceaws.AwsApiGatewayV2VpcLinkResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2VpcLinkResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2VpcLinkEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Authorizer(t *testing.T) { - dummyError := errors.New("this is an error") - - apis := []*apigatewayv2.Api{ - {ApiId: awssdk.String("bmyl5c6huh")}, - {ApiId: awssdk.String("blghshbgte")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 authorizers", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiAuthorizers", *apis[0].ApiId).Return([]*apigatewayv2.Authorizer{}, nil).Once() - repo.On("ListAllApiAuthorizers", *apis[1].ApiId).Return([]*apigatewayv2.Authorizer{}, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway v2 authorizers", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiAuthorizers", *apis[0].ApiId).Return([]*apigatewayv2.Authorizer{ - {AuthorizerId: awssdk.String("xaappu")}, - }, nil).Once() - repo.On("ListAllApiAuthorizers", *apis[1].ApiId).Return([]*apigatewayv2.Authorizer{ - {AuthorizerId: awssdk.String("bwhebj")}, - }, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "xaappu") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2AuthorizerResourceType) - - assert.Equal(t, got[1].ResourceId(), "bwhebj") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayV2AuthorizerResourceType) - }, - }, - { - test: "cannot list apis", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2AuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - { - test: "cannot list api gateway v2 authorizers", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiAuthorizers", *apis[0].ApiId).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2AuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType, resourceaws.AwsApiGatewayV2AuthorizerResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2AuthorizerResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2AuthorizerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Integration(t *testing.T) { - dummyError := errors.New("this is an error") - - apis := []*apigatewayv2.Api{ - {ApiId: awssdk.String("bmyl5c6huh")}, - {ApiId: awssdk.String("blghshbgte")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 integrations", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiIntegrations", *apis[0].ApiId).Return([]*apigatewayv2.Integration{}, nil).Once() - repo.On("ListAllApiIntegrations", *apis[1].ApiId).Return([]*apigatewayv2.Integration{}, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway v2 integrations", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiIntegrations", *apis[0].ApiId).Return([]*apigatewayv2.Integration{ - { - IntegrationId: awssdk.String("xaappu"), - IntegrationType: awssdk.String("MOCK"), - }, - }, nil).Once() - repo.On("ListAllApiIntegrations", *apis[1].ApiId).Return([]*apigatewayv2.Integration{ - { - IntegrationId: awssdk.String("bwhebj"), - IntegrationType: awssdk.String("MOCK"), - }, - }, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "xaappu") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2IntegrationResourceType) - - assert.Equal(t, got[1].ResourceId(), "bwhebj") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayV2IntegrationResourceType) - }, - }, - { - test: "cannot list apis", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - { - test: "cannot list api gateway v2 integrations", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiIntegrations", *apis[0].ApiId).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2IntegrationResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2IntegrationEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Model(t *testing.T) { - dummyError := errors.New("this is an error") - - apis := []*apigatewayv2.Api{ - {ApiId: awssdk.String("bmyl5c6huh")}, - {ApiId: awssdk.String("blghshbgte")}, - } - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 models", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiModels", *apis[0].ApiId).Return([]*apigatewayv2.Model{}, nil).Once() - repo.On("ListAllApiModels", *apis[1].ApiId).Return([]*apigatewayv2.Model{}, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple api gateway v2 models", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiModels", *apis[0].ApiId).Return([]*apigatewayv2.Model{ - { - ModelId: awssdk.String("vdw6up"), - Name: awssdk.String("model1"), - }, - }, nil).Once() - repo.On("ListAllApiModels", *apis[1].ApiId).Return([]*apigatewayv2.Model{ - { - ModelId: awssdk.String("bwhebj"), - Name: awssdk.String("model2"), - }, - }, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "vdw6up") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2ModelResourceType) - assert.Equal(t, "model1", *got[0].Attributes().GetString("name")) - - assert.Equal(t, got[1].ResourceId(), "bwhebj") - assert.Equal(t, got[1].ResourceType(), resourceaws.AwsApiGatewayV2ModelResourceType) - assert.Equal(t, "model2", *got[1].Attributes().GetString("name")) - - }, - }, - { - test: "cannot list apis", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2ModelResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - { - test: "cannot list api gateway v2 model", - mocks: func(repo *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repo.On("ListAllApis").Return(apis, nil) - repo.On("ListAllApiModels", *apis[0].ApiId).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2ModelResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType, resourceaws.AwsApiGatewayV2ModelResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2ModelResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2ModelEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Stage(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 api", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 api with a single stage", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("a-gateway")}, - }, nil) - repository.On("ListAllApiStages", "a-gateway"). - Return([]*apigatewayv2.Stage{{ - StageName: awssdk.String("a-stage"), - }}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, "a-stage", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsApiGatewayV2StageResourceType, got[0].ResourceType()) - }, - }, - { - test: "cannot list api gateway v2 apis", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2StageResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2StageResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2StageResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2StageEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2RouteResponse(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 route responses", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("77ooqulkke")}, - }, nil) - repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")). - Return([]*apigatewayv2.Route{ - {RouteId: awssdk.String("liqc5u4")}, - }, nil) - repository.On("ListAllApiRouteResponses", "77ooqulkke", "liqc5u4"). - Return([]*apigatewayv2.RouteResponse{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 route with one route response", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("77ooqulkke")}, - }, nil) - repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")). - Return([]*apigatewayv2.Route{ - {RouteId: awssdk.String("liqc5u4")}, - }, nil) - repository.On("ListAllApiRouteResponses", "77ooqulkke", "liqc5u4"). - Return([]*apigatewayv2.RouteResponse{ - {RouteResponseId: awssdk.String("nbw7vw")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "nbw7vw") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2RouteResponseResourceType) - }, - }, - { - test: "cannot list api gateway v2 apis", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - { - test: "cannot list api gateway v2 routes", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("77ooqulkke")}, - }, nil) - repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")).Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2RouteResourceType), - }, - { - test: "cannot list api gateway v2 route responses", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("77ooqulkke")}, - }, nil) - repository.On("ListAllApiRoutes", awssdk.String("77ooqulkke")). - Return([]*apigatewayv2.Route{ - {RouteId: awssdk.String("liqc5u4")}, - }, nil) - repository.On("ListAllApiRouteResponses", "77ooqulkke", "liqc5u4").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2RouteResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2RouteResponseResourceType, resourceaws.AwsApiGatewayV2RouteResponseResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2RouteResponseEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2Mapping(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 domains", - mocks: func(repositoryV1 *repository.MockApiGatewayRepository, repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 domain with a single mapping", - mocks: func(repositoryV1 *repository.MockApiGatewayRepository, repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{ - {DomainName: awssdk.String("example.com")}, - }, nil) - repository.On("ListAllApiMappings", "example.com"). - Return([]*apigatewayv2.ApiMapping{{ - Stage: awssdk.String("a-stage"), - ApiId: awssdk.String("foobar"), - ApiMappingId: awssdk.String("barfoo"), - }}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, "barfoo", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsApiGatewayV2MappingResourceType, got[0].ResourceType()) - }, - }, - { - test: "cannot list api gateway v2 domains", - mocks: func(repositoryV1 *repository.MockApiGatewayRepository, repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repositoryV1.On("ListAllDomainNames").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2MappingResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayDomainNameResourceType, resourceaws.AwsApiGatewayV2MappingResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2MappingResourceType, resourceaws.AwsApiGatewayDomainNameResourceType), - }, - { - test: "cannot list api gateway v2 mappings", - mocks: func(repositoryV1 *repository.MockApiGatewayRepository, repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{ - {DomainName: awssdk.String("example.com")}, - }, nil) - repository.On("ListAllApiMappings", "example.com"). - Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2MappingResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2MappingResourceType, resourceaws.AwsApiGatewayV2MappingResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2MappingResourceType), - }, - { - test: "returning mapping with invalid attributes", - mocks: func(repositoryV1 *repository.MockApiGatewayRepository, repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repositoryV1.On("ListAllDomainNames").Return([]*apigateway.DomainName{ - {DomainName: awssdk.String("example.com")}, - }, nil) - repository.On("ListAllApiMappings", "example.com"). - Return([]*apigatewayv2.ApiMapping{ - { - ApiMappingId: awssdk.String("barfoo"), - }, - { - Stage: awssdk.String("a-stage"), - ApiId: awssdk.String("foobar"), - ApiMappingId: awssdk.String("foobar"), - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, "barfoo", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsApiGatewayV2MappingResourceType, got[0].ResourceType()) - - assert.Equal(t, "foobar", got[1].ResourceId()) - assert.Equal(t, resourceaws.AwsApiGatewayV2MappingResourceType, got[1].ResourceType()) - }, - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepoV1 := &repository.MockApiGatewayRepository{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepoV1, fakeRepo, alerter) - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2MappingEnumerator(fakeRepo, fakeRepoV1, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - fakeRepoV1.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2DomainName(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 domain names", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 domain name", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDomainNames").Return([]*apigateway.DomainName{ - {DomainName: awssdk.String("b8r351.example.com")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "b8r351.example.com") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2DomainNameResourceType) - }, - }, - { - test: "cannot list api gateway v2 domain names", - mocks: func(repository *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDomainNames").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2DomainNameResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2DomainNameResourceType, resourceaws.AwsApiGatewayV2DomainNameResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayV2DomainNameResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2DomainNameEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestApiGatewayV2IntegrationResponse(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockApiGatewayV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no api gateway v2 integration responses", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("yw28nwdf34")}, - }, nil) - repository.On("ListAllApiIntegrations", "yw28nwdf34"). - Return([]*apigatewayv2.Integration{ - {IntegrationId: awssdk.String("fmezvlh")}, - }, nil) - repository.On("ListAllApiIntegrationResponses", "yw28nwdf34", "fmezvlh"). - Return([]*apigatewayv2.IntegrationResponse{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "single api gateway v2 integration with one integration response", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("yw28nwdf34")}, - }, nil) - repository.On("ListAllApiIntegrations", "yw28nwdf34"). - Return([]*apigatewayv2.Integration{ - {IntegrationId: awssdk.String("fmezvlh")}, - }, nil) - repository.On("ListAllApiIntegrationResponses", "yw28nwdf34", "fmezvlh"). - Return([]*apigatewayv2.IntegrationResponse{ - {IntegrationResponseId: awssdk.String("sf67ti7")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "sf67ti7") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsApiGatewayV2IntegrationResponseResourceType) - }, - }, - { - test: "cannot list api gateway v2 apis", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2ApiResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2ApiResourceType), - }, - { - test: "cannot list api gateway v2 integrations", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("yw28nwdf34")}, - }, nil) - repository.On("ListAllApiIntegrations", "yw28nwdf34").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2IntegrationResourceType), - }, - { - test: "cannot list api gateway v2 integration responses", - mocks: func(repository *repository.MockApiGatewayV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllApis").Return([]*apigatewayv2.Api{ - {ApiId: awssdk.String("yw28nwdf34")}, - }, nil) - repository.On("ListAllApiIntegrations", "yw28nwdf34"). - Return([]*apigatewayv2.Integration{ - {IntegrationId: awssdk.String("fmezvlh")}, - }, nil) - repository.On("ListAllApiIntegrationResponses", "yw28nwdf34", "fmezvlh").Return(nil, dummyError) - alerter.On("SendAlert", resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType, resourceaws.AwsApiGatewayV2IntegrationResponseResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockApiGatewayV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ApiGatewayV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewApiGatewayV2IntegrationResponseEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_applicationautoscaling_scanner_test.go b/pkg/remote/aws_applicationautoscaling_scanner_test.go deleted file mode 100644 index 9830de45..00000000 --- a/pkg/remote/aws_applicationautoscaling_scanner_test.go +++ /dev/null @@ -1,326 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/applicationautoscaling" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAppAutoScalingTarget(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockAppAutoScalingRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "should return one target", - dirName: "aws_appautoscaling_target_single", - mocks: func(client *repository.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { - client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() - - client.On("DescribeScalableTargets", "dynamodb").Return([]*applicationautoscaling.ScalableTarget{ - { - ResourceId: awssdk.String("table/GameScores"), - RoleARN: awssdk.String("arn:aws:iam::533948124879:role/aws-service-role/dynamodb.application-autoscaling.amazonaws.com/AWSServiceRoleForApplicationAutoScaling_DynamoDBTable"), - ScalableDimension: awssdk.String("dynamodb:table:ReadCapacityUnits"), - ServiceNamespace: awssdk.String("dynamodb"), - MaxCapacity: awssdk.Int64(100), - MinCapacity: awssdk.Int64(5), - }, - }, nil).Once() - - client.On("DescribeScalableTargets", mock.AnythingOfType("string")).Return([]*applicationautoscaling.ScalableTarget{}, nil).Times(len(applicationautoscaling.ServiceNamespace_Values()) - 1) - }, - wantErr: nil, - }, - { - test: "should return remote error", - dirName: "aws_appautoscaling_target_single", - mocks: func(client *repository.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { - client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() - - client.On("DescribeScalableTargets", mock.AnythingOfType("string")).Return(nil, errors.New("remote error")).Once() - }, - wantErr: remoteerror.NewResourceListingError(errors.New("remote error"), resourceaws.AwsAppAutoscalingTargetResourceType), - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockAppAutoScalingRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.AppAutoScalingRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewAppAutoScalingRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewAppAutoscalingTargetEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsAppAutoscalingTargetResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsAppAutoscalingTargetResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - if err != nil { - assert.EqualError(tt, c.wantErr, err.Error()) - } else { - assert.Equal(tt, err, c.wantErr) - } - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsAppAutoscalingTargetResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAppAutoScalingPolicy(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockAppAutoScalingRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "should return one policy", - dirName: "aws_appautoscaling_policy_single", - mocks: func(client *repository.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { - client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() - - client.On("DescribeScalingPolicies", "dynamodb").Return([]*applicationautoscaling.ScalingPolicy{ - { - PolicyName: awssdk.String("DynamoDBReadCapacityUtilization:table/GameScores"), - ResourceId: awssdk.String("table/GameScores"), - ScalableDimension: awssdk.String("dynamodb:table:ReadCapacityUnits"), - ServiceNamespace: awssdk.String("dynamodb"), - }, - }, nil).Once() - - client.On("DescribeScalingPolicies", mock.AnythingOfType("string")).Return([]*applicationautoscaling.ScalingPolicy{}, nil).Times(len(applicationautoscaling.ServiceNamespace_Values()) - 1) - }, - wantErr: nil, - }, - { - test: "should return remote error", - dirName: "aws_appautoscaling_policy_single", - mocks: func(client *repository.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { - client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() - - client.On("DescribeScalingPolicies", mock.AnythingOfType("string")).Return(nil, errors.New("remote error")).Once() - }, - wantErr: remoteerror.NewResourceListingError(errors.New("remote error"), resourceaws.AwsAppAutoscalingPolicyResourceType), - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockAppAutoScalingRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.AppAutoScalingRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewAppAutoScalingRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewAppAutoscalingPolicyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsAppAutoscalingPolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsAppAutoscalingPolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - if err != nil { - assert.EqualError(tt, c.wantErr, err.Error()) - } else { - assert.Equal(tt, err, c.wantErr) - } - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsAppAutoscalingPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAppAutoScalingScheduledAction(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockAppAutoScalingRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "should return one scheduled action", - mocks: func(client *repository.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { - matchServiceNamespaceFunc := func(ns string) bool { - for _, n := range applicationautoscaling.ServiceNamespace_Values() { - if n == ns { - return true - } - } - return false - } - - client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() - - client.On("DescribeScheduledActions", mock.MatchedBy(matchServiceNamespaceFunc)).Return([]*applicationautoscaling.ScheduledAction{ - { - ScheduledActionName: awssdk.String("action"), - ResourceId: awssdk.String("table/GameScores"), - ScalableDimension: awssdk.String("dynamodb:table:ReadCapacityUnits"), - ServiceNamespace: awssdk.String("dynamodb"), - }, - }, nil).Once() - - client.On("DescribeScheduledActions", mock.MatchedBy(matchServiceNamespaceFunc)).Return([]*applicationautoscaling.ScheduledAction{}, nil).Times(len(applicationautoscaling.ServiceNamespace_Values()) - 1) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "action-dynamodb-table/GameScores", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsAppAutoscalingScheduledActionResourceType, got[0].ResourceType()) - }, - wantErr: nil, - }, - { - test: "should return remote error", - mocks: func(client *repository.MockAppAutoScalingRepository, alerter *mocks.AlerterInterface) { - client.On("ServiceNamespaceValues").Return(applicationautoscaling.ServiceNamespace_Values()).Once() - - client.On("DescribeScheduledActions", mock.AnythingOfType("string")).Return(nil, dummyError).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - wantErr: remoteerror.NewResourceListingError(dummyError, resourceaws.AwsAppAutoscalingScheduledActionResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockAppAutoScalingRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.AppAutoScalingRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewAppAutoscalingScheduledActionEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_autoscaling_scanner_test.go b/pkg/remote/aws_autoscaling_scanner_test.go deleted file mode 100644 index 1a68df40..00000000 --- a/pkg/remote/aws_autoscaling_scanner_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/autoscaling" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAutoscaling_LaunchConfiguration(t *testing.T) { - tests := []struct { - test string - mocks func(*repository.MockAutoScalingRepository, *mocks.AlerterInterface) - assertExpected func(*testing.T, []*resource.Resource) - wantErr error - }{ - { - test: "no launch configuration", - mocks: func(repository *repository.MockAutoScalingRepository, alerter *mocks.AlerterInterface) { - repository.On("DescribeLaunchConfigurations").Return([]*autoscaling.LaunchConfiguration{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple launch configurations", - mocks: func(repository *repository.MockAutoScalingRepository, alerter *mocks.AlerterInterface) { - repository.On("DescribeLaunchConfigurations").Return([]*autoscaling.LaunchConfiguration{ - {LaunchConfigurationName: awssdk.String("web_config_1")}, - {LaunchConfigurationName: awssdk.String("web_config_2")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, "web_config_1", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsLaunchConfigurationResourceType, got[0].ResourceType()) - - assert.Equal(t, "web_config_2", got[1].ResourceId()) - assert.Equal(t, resourceaws.AwsLaunchConfigurationResourceType, got[1].ResourceType()) - }, - }, - { - test: "cannot list launch configurations", - mocks: func(repository *repository.MockAutoScalingRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("DescribeLaunchConfigurations").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsLaunchConfigurationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLaunchConfigurationResourceType, resourceaws.AwsLaunchConfigurationResourceType), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockAutoScalingRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.AutoScalingRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewLaunchConfigurationEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_cloudformation_scanner_test.go b/pkg/remote/aws_cloudformation_scanner_test.go deleted file mode 100644 index e95b98ec..00000000 --- a/pkg/remote/aws_cloudformation_scanner_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudformation" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestCloudformationStack(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockCloudformationRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no cloudformation stacks", - dirName: "aws_cloudformation_stack_empty", - mocks: func(repository *repository.MockCloudformationRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStacks").Return([]*cloudformation.Stack{}, nil) - }, - }, - { - test: "multiple cloudformation stacks", - dirName: "aws_cloudformation_stack_multiple", - mocks: func(repository *repository.MockCloudformationRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStacks").Return([]*cloudformation.Stack{ - {StackId: awssdk.String("arn:aws:cloudformation:us-east-1:047081014315:stack/bar-stack/c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab")}, - {StackId: awssdk.String("arn:aws:cloudformation:us-east-1:047081014315:stack/foo-stack/c7aa0ab0-0f21-11ec-ba25-129d8c0b3757")}, - }, nil) - }, - }, - { - test: "cannot list cloudformation stacks", - dirName: "aws_cloudformation_stack_list", - mocks: func(repository *repository.MockCloudformationRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") - repository.On("ListAllStacks").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsCloudformationStackResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsCloudformationStackResourceType, resourceaws.AwsCloudformationStackResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockCloudformationRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.CloudformationRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewCloudformationRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewCloudformationStackEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsCloudformationStackResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsCloudformationStackResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsCloudformationStackResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_cloudfront_scanner_test.go b/pkg/remote/aws_cloudfront_scanner_test.go deleted file mode 100644 index fa6ff6d5..00000000 --- a/pkg/remote/aws_cloudfront_scanner_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/cloudfront" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestCloudfrontDistribution(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockCloudfrontRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no cloudfront distributions", - dirName: "aws_cloudfront_distribution_empty", - mocks: func(repository *repository.MockCloudfrontRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDistributions").Return([]*cloudfront.DistributionSummary{}, nil) - }, - }, - { - test: "single cloudfront distribution", - dirName: "aws_cloudfront_distribution_single", - mocks: func(repository *repository.MockCloudfrontRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDistributions").Return([]*cloudfront.DistributionSummary{ - {Id: awssdk.String("E1M9CNS0XSHI19")}, - }, nil) - }, - }, - { - test: "cannot list cloudfront distributions", - dirName: "aws_cloudfront_distribution_list", - mocks: func(repository *repository.MockCloudfrontRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") - repository.On("ListAllDistributions").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsCloudfrontDistributionResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsCloudfrontDistributionResourceType, resourceaws.AwsCloudfrontDistributionResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockCloudfrontRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.CloudfrontRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewCloudfrontRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewCloudfrontDistributionEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsCloudfrontDistributionResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsCloudfrontDistributionResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsCloudfrontDistributionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_dynamodb_scanner_test.go b/pkg/remote/aws_dynamodb_scanner_test.go deleted file mode 100644 index b1ad5643..00000000 --- a/pkg/remote/aws_dynamodb_scanner_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestDynamoDBTable(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockDynamoDBRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no DynamoDB Table", - dirName: "aws_dynamodb_table_empty", - mocks: func(client *repository.MockDynamoDBRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllTables").Return([]*string{}, nil) - }, - wantErr: nil, - }, - { - test: "Multiple DynamoDB Table", - dirName: "aws_dynamodb_table_multiple", - mocks: func(client *repository.MockDynamoDBRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllTables").Return([]*string{ - awssdk.String("GameScores"), - awssdk.String("example"), - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list DynamoDB Table", - dirName: "aws_dynamodb_table_list", - mocks: func(client *repository.MockDynamoDBRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") - client.On("ListAllTables").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDynamodbTableResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDynamodbTableResourceType, resourceaws.AwsDynamodbTableResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockDynamoDBRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.DynamoDBRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewDynamoDBRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewDynamoDBTableEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDynamodbTableResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDynamodbTableResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDynamodbTableResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_ec2_scanner_test.go b/pkg/remote/aws_ec2_scanner_test.go deleted file mode 100644 index 49ea7c51..00000000 --- a/pkg/remote/aws_ec2_scanner_test.go +++ /dev/null @@ -1,2905 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestEC2EbsVolume(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no volumes", - dirName: "aws_ec2_ebs_volume_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVolumes").Return([]*ec2.Volume{}, nil) - }, - }, - { - test: "multiple volumes", - dirName: "aws_ec2_ebs_volume_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVolumes").Return([]*ec2.Volume{ - {VolumeId: awssdk.String("vol-081c7272a57a09db1")}, - {VolumeId: awssdk.String("vol-01ddc91d3d9d1318b")}, - }, nil) - }, - }, - { - test: "cannot list volumes", - dirName: "aws_ec2_ebs_volume_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllVolumes").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsEbsVolumeResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEbsVolumeResourceType, resourceaws.AwsEbsVolumeResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2EbsVolumeEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsEbsVolumeResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsEbsVolumeResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2EbsSnapshot(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no snapshots", - dirName: "aws_ec2_ebs_snapshot_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSnapshots").Return([]*ec2.Snapshot{}, nil) - }, - }, - { - test: "multiple snapshots", - dirName: "aws_ec2_ebs_snapshot_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSnapshots").Return([]*ec2.Snapshot{ - {SnapshotId: awssdk.String("snap-0c509a2a880d95a39")}, - {SnapshotId: awssdk.String("snap-00672558cecd93a61")}, - }, nil) - }, - }, - { - test: "cannot list snapshots", - dirName: "aws_ec2_ebs_snapshot_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllSnapshots").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsEbsSnapshotResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEbsSnapshotResourceType, resourceaws.AwsEbsSnapshotResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2EbsSnapshotEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsEbsSnapshotResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsEbsSnapshotResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsEbsSnapshotResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2Eip(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no eips", - dirName: "aws_ec2_eip_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllAddresses").Return([]*ec2.Address{ - {}, // Test Eip without AllocationId because it can happen (seen in sentry) - }, nil) - }, - }, - { - test: "multiple eips", - dirName: "aws_ec2_eip_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllAddresses").Return([]*ec2.Address{ - {AllocationId: awssdk.String("eipalloc-017d5267e4dda73f1")}, - {AllocationId: awssdk.String("eipalloc-0cf714dc097c992cc")}, - }, nil) - }, - }, - { - test: "cannot list eips", - dirName: "aws_ec2_eip_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllAddresses").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsEipResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEipResourceType, resourceaws.AwsEipResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2EipEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsEipResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsEipResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsEipResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2Ami(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no amis", - dirName: "aws_ec2_ami_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllImages").Return([]*ec2.Image{}, nil) - }, - }, - { - test: "multiple amis", - dirName: "aws_ec2_ami_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllImages").Return([]*ec2.Image{ - {ImageId: awssdk.String("ami-03a578b46f4c3081b")}, - {ImageId: awssdk.String("ami-025962fd8b456731f")}, - }, nil) - }, - }, - { - test: "cannot list ami", - dirName: "aws_ec2_ami_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllImages").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsAmiResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsAmiResourceType, resourceaws.AwsAmiResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2AmiEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsAmiResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsAmiResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsAmiResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2KeyPair(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no key pairs", - dirName: "aws_ec2_key_pair_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllKeyPairs").Return([]*ec2.KeyPairInfo{}, nil) - }, - }, - { - test: "multiple key pairs", - dirName: "aws_ec2_key_pair_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllKeyPairs").Return([]*ec2.KeyPairInfo{ - {KeyName: awssdk.String("test")}, - {KeyName: awssdk.String("bar")}, - }, nil) - }, - }, - { - test: "cannot list key pairs", - dirName: "aws_ec2_key_pair_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllKeyPairs").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsKeyPairResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsKeyPairResourceType, resourceaws.AwsKeyPairResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2KeyPairEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsKeyPairResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsKeyPairResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsKeyPairResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2EipAssociation(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no eip associations", - dirName: "aws_ec2_eip_association_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{}, nil) - }, - }, - { - test: "single eip association", - dirName: "aws_ec2_eip_association_single", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllAddressesAssociation").Return([]*ec2.Address{ - { - AssociationId: awssdk.String("eipassoc-0e9a7356e30f0c3d1"), - AllocationId: awssdk.String("eipalloc-017d5267e4dda73f1"), - }, - }, nil) - }, - }, - { - test: "cannot list eip associations", - dirName: "aws_ec2_eip_association_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllAddressesAssociation").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsEipAssociationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEipAssociationResourceType, resourceaws.AwsEipAssociationResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2EipAssociationEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsEipAssociationResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsEipAssociationResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsEipAssociationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2Instance(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no instances", - dirName: "aws_ec2_instance_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllInstances").Return([]*ec2.Instance{}, nil) - }, - }, - { - test: "multiple instances", - dirName: "aws_ec2_instance_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllInstances").Return([]*ec2.Instance{ - {InstanceId: awssdk.String("i-0d3650a23f4e45dc0")}, - {InstanceId: awssdk.String("i-010376047a71419f1")}, - }, nil) - }, - }, - { - test: "terminated instances", - dirName: "aws_ec2_instance_terminated", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllInstances").Return([]*ec2.Instance{ - {InstanceId: awssdk.String("i-0e1543baf4f2cd990")}, - {InstanceId: awssdk.String("i-0a3a7ed51ae2b4fa0")}, // Nil - }, nil) - }, - }, - { - test: "cannot list instances", - dirName: "aws_ec2_instance_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllInstances").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsInstanceResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsInstanceResourceType, resourceaws.AwsInstanceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2InstanceEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsInstanceResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsInstanceResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsInstanceResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2InternetGateway(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no internet gateways", - dirName: "aws_ec2_internet_gateway_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllInternetGateways").Return([]*ec2.InternetGateway{}, nil) - }, - }, - { - test: "multiple internet gateways", - dirName: "aws_ec2_internet_gateway_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllInternetGateways").Return([]*ec2.InternetGateway{ - {InternetGatewayId: awssdk.String("igw-0184eb41aadc62d1c")}, - {InternetGatewayId: awssdk.String("igw-047b487f5c60fca99")}, - }, nil) - }, - }, - { - test: "cannot list internet gateways", - dirName: "aws_ec2_internet_gateway_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllInternetGateways").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsInternetGatewayResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsInternetGatewayResourceType, resourceaws.AwsInternetGatewayResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2InternetGatewayEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsInternetGatewayResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsInternetGatewayResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsInternetGatewayResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestVPC(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no VPC", - dirName: "aws_vpc_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{}, []*ec2.Vpc{}, nil) - }, - wantErr: nil, - }, - { - test: "VPC results", - dirName: "aws_vpc", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{ - { - VpcId: awssdk.String("vpc-0768e1fd0029e3fc3"), - }, - { - VpcId: awssdk.String("vpc-020b072316a95b97f"), - IsDefault: awssdk.Bool(false), - }, - { - VpcId: awssdk.String("vpc-02c50896b59598761"), - IsDefault: awssdk.Bool(false), - }, - }, []*ec2.Vpc{ - { - VpcId: awssdk.String("vpc-a8c5d4c1"), - IsDefault: awssdk.Bool(false), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list VPC", - dirName: "aws_vpc_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllVPCs").Once().Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsVpcResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsVpcResourceType, resourceaws.AwsVpcResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewVPCEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsVpcResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsVpcResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsVpcResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestDefaultVPC(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no VPC", - dirName: "aws_vpc_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{}, []*ec2.Vpc{}, nil) - }, - wantErr: nil, - }, - { - test: "default VPC results", - dirName: "aws_default_vpc", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllVPCs").Once().Return([]*ec2.Vpc{ - { - VpcId: awssdk.String("vpc-0768e1fd0029e3fc3"), - IsDefault: awssdk.Bool(false), - }, - { - VpcId: awssdk.String("vpc-020b072316a95b97f"), - IsDefault: awssdk.Bool(false), - }, - }, []*ec2.Vpc{ - { - VpcId: awssdk.String("vpc-a8c5d4c1"), - IsDefault: awssdk.Bool(true), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list VPC", - dirName: "aws_vpc_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllVPCs").Once().Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDefaultVpcResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultVpcResourceType, resourceaws.AwsDefaultVpcResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewDefaultVPCEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultVpcResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDefaultVpcResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultVpcResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2RouteTableAssociation(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no route table associations (test for nil values)", - dirName: "aws_ec2_route_table_association_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ - { - RouteTableId: awssdk.String("assoc_with_nil"), - Associations: []*ec2.RouteTableAssociation{ - { - AssociationState: nil, - GatewayId: nil, - Main: nil, - RouteTableAssociationId: nil, - RouteTableId: nil, - SubnetId: nil, - }, - }, - }, - {RouteTableId: awssdk.String("nil_assoc")}, - }, nil) - }, - }, - { - test: "multiple route table associations (mixed subnet and gateway associations)", - dirName: "aws_ec2_route_table_association_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ - { - RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), // route - Associations: []*ec2.RouteTableAssociation{ - { // Should be ignored - AssociationState: &ec2.RouteTableAssociationState{ - State: awssdk.String("disassociated"), - }, - GatewayId: awssdk.String("dummy-id"), - }, - { // Should be ignored - SubnetId: nil, - GatewayId: nil, - }, - { // assoc_route_subnet1 - AssociationState: &ec2.RouteTableAssociationState{ - State: awssdk.String("associated"), - }, - Main: awssdk.Bool(false), - RouteTableAssociationId: awssdk.String("rtbassoc-0809598f92dbec03b"), - RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), - SubnetId: awssdk.String("subnet-05185af647b2eeda3"), - }, - { // assoc_route_subnet - AssociationState: &ec2.RouteTableAssociationState{ - State: awssdk.String("associated"), - }, - Main: awssdk.Bool(false), - RouteTableAssociationId: awssdk.String("rtbassoc-01957791b2cfe6ea4"), - RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), - SubnetId: awssdk.String("subnet-0e93dbfa2e5dd8282"), - }, - { // assoc_route_subnet2 - AssociationState: &ec2.RouteTableAssociationState{ - State: awssdk.String("associated"), - }, - GatewayId: nil, - Main: awssdk.Bool(false), - RouteTableAssociationId: awssdk.String("rtbassoc-0b4f97ea57490e213"), - RouteTableId: awssdk.String("rtb-05aa6c5673311a17b"), - SubnetId: awssdk.String("subnet-0fd966efd884d0362"), - }, - }, - }, - { - RouteTableId: awssdk.String("rtb-09df7cc9d16de9f8f"), // route2 - Associations: []*ec2.RouteTableAssociation{ - { // assoc_route2_gateway - AssociationState: &ec2.RouteTableAssociationState{ - State: awssdk.String("associated"), - }, - RouteTableAssociationId: awssdk.String("rtbassoc-0a79ccacfceb4944b"), - RouteTableId: awssdk.String("rtb-09df7cc9d16de9f8f"), - GatewayId: awssdk.String("igw-0238f6e09185ac954"), - }, - }, - }, - }, nil) - }, - }, - { - test: "cannot list route table associations", - dirName: "aws_ec2_route_table_association_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllRouteTables").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsRouteTableAssociationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRouteTableAssociationResourceType, resourceaws.AwsRouteTableResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2RouteTableAssociationEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRouteTableAssociationResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRouteTableAssociationResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRouteTableAssociationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2Subnet(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no subnets", - dirName: "aws_ec2_subnet_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSubnets").Return([]*ec2.Subnet{}, []*ec2.Subnet{}, nil) - }, - }, - { - test: "multiple subnets", - dirName: "aws_ec2_subnet_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSubnets").Return([]*ec2.Subnet{ - { - SubnetId: awssdk.String("subnet-05810d3f933925f6d"), // subnet1 - DefaultForAz: awssdk.Bool(false), - }, - { - SubnetId: awssdk.String("subnet-0b13f1e0eacf67424"), // subnet2 - DefaultForAz: awssdk.Bool(false), - }, - { - SubnetId: awssdk.String("subnet-0c9b78001fe186e22"), // subnet3 - DefaultForAz: awssdk.Bool(false), - }, - }, []*ec2.Subnet{ - { - SubnetId: awssdk.String("subnet-44fe0c65"), // us-east-1a - DefaultForAz: awssdk.Bool(true), - }, - { - SubnetId: awssdk.String("subnet-65e16628"), // us-east-1b - DefaultForAz: awssdk.Bool(true), - }, - { - SubnetId: awssdk.String("subnet-afa656f0"), // us-east-1c - DefaultForAz: awssdk.Bool(true), - }, - }, nil) - }, - }, - { - test: "cannot list subnets", - dirName: "aws_ec2_subnet_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllSubnets").Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSubnetResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSubnetResourceType, resourceaws.AwsSubnetResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2SubnetEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSubnetResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSubnetResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSubnetResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2DefaultSubnet(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no default subnets", - dirName: "aws_ec2_default_subnet_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSubnets").Return([]*ec2.Subnet{}, []*ec2.Subnet{}, nil) - }, - }, - { - test: "multiple default subnets", - dirName: "aws_ec2_default_subnet_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSubnets").Return([]*ec2.Subnet{ - { - SubnetId: awssdk.String("subnet-05810d3f933925f6d"), // subnet1 - DefaultForAz: awssdk.Bool(false), - }, - { - SubnetId: awssdk.String("subnet-0b13f1e0eacf67424"), // subnet2 - DefaultForAz: awssdk.Bool(false), - }, - { - SubnetId: awssdk.String("subnet-0c9b78001fe186e22"), // subnet3 - DefaultForAz: awssdk.Bool(false), - }, - }, []*ec2.Subnet{ - { - SubnetId: awssdk.String("subnet-44fe0c65"), // us-east-1a - DefaultForAz: awssdk.Bool(true), - }, - { - SubnetId: awssdk.String("subnet-65e16628"), // us-east-1b - DefaultForAz: awssdk.Bool(true), - }, - { - SubnetId: awssdk.String("subnet-afa656f0"), // us-east-1c - DefaultForAz: awssdk.Bool(true), - }, - }, nil) - }, - }, - { - test: "cannot list default subnets", - dirName: "aws_ec2_default_subnet_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllSubnets").Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDefaultSubnetResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultSubnetResourceType, resourceaws.AwsDefaultSubnetResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2DefaultSubnetEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultSubnetResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDefaultSubnetResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultSubnetResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2RouteTable(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no route tables", - dirName: "aws_ec2_route_table_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{}, nil) - }, - }, - { - test: "multiple route tables", - dirName: "aws_ec2_route_table_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ - {RouteTableId: awssdk.String("rtb-08b7b71af15e183ce")}, // table1 - {RouteTableId: awssdk.String("rtb-0002ac731f6fdea55")}, // table2 - {RouteTableId: awssdk.String("rtb-0c55d55593f33fbac")}, // table3 - { - RouteTableId: awssdk.String("rtb-0eabf071c709c0976"), // default_table - VpcId: awssdk.String("vpc-0b4a6b3536da20ecd"), - Associations: []*ec2.RouteTableAssociation{ - { - Main: awssdk.Bool(true), - }, - }, - }, - }, nil) - }, - }, - { - test: "cannot list route tables", - dirName: "aws_ec2_route_table_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllRouteTables").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsRouteTableResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRouteTableResourceType, resourceaws.AwsRouteTableResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2RouteTableEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRouteTableResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRouteTableResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRouteTableResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2DefaultRouteTable(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no default route tables", - dirName: "aws_ec2_default_route_table_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{}, nil) - }, - }, - { - test: "multiple default route tables", - dirName: "aws_ec2_default_route_table_single", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ - {RouteTableId: awssdk.String("rtb-08b7b71af15e183ce")}, // table1 - {RouteTableId: awssdk.String("rtb-0002ac731f6fdea55")}, // table2 - {RouteTableId: awssdk.String("rtb-0c55d55593f33fbac")}, // table3 - { - RouteTableId: awssdk.String("rtb-0eabf071c709c0976"), // default_table - VpcId: awssdk.String("vpc-0b4a6b3536da20ecd"), - Associations: []*ec2.RouteTableAssociation{ - { - Main: awssdk.Bool(true), - }, - }, - }, - }, nil) - }, - }, - { - test: "cannot list default route tables", - dirName: "aws_ec2_default_route_table_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllRouteTables").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDefaultRouteTableResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultRouteTableResourceType, resourceaws.AwsDefaultRouteTableResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2DefaultRouteTableEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultRouteTableResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDefaultRouteTableResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultRouteTableResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestVpcSecurityGroup(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no security groups", - dirName: "aws_vpc_security_group_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{}, []*ec2.SecurityGroup{}, nil) - }, - wantErr: nil, - }, - { - test: "with security groups", - dirName: "aws_vpc_security_group_multiple", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ - { - GroupId: awssdk.String("sg-0254c038e32f25530"), - GroupName: awssdk.String("foo"), - }, - }, []*ec2.SecurityGroup{ - { - GroupId: awssdk.String("sg-9e0204ff"), - GroupName: awssdk.String("default"), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list security groups", - dirName: "aws_vpc_security_group_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllSecurityGroups").Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSecurityGroupResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSecurityGroupResourceType, resourceaws.AwsSecurityGroupResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewVPCSecurityGroupEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSecurityGroupResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSecurityGroupResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSecurityGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestVpcDefaultSecurityGroup(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no security groups", - dirName: "aws_vpc_default_security_group_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{}, []*ec2.SecurityGroup{}, nil) - }, - wantErr: nil, - }, - { - test: "with security groups", - dirName: "aws_vpc_default_security_group_multiple", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ - { - GroupId: awssdk.String("sg-0254c038e32f25530"), - GroupName: awssdk.String("foo"), - }, - }, []*ec2.SecurityGroup{ - { - GroupId: awssdk.String("sg-9e0204ff"), - GroupName: awssdk.String("default"), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list security groups", - dirName: "aws_vpc_default_security_group_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllSecurityGroups").Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDefaultSecurityGroupResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDefaultSecurityGroupResourceType, resourceaws.AwsDefaultSecurityGroupResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewVPCDefaultSecurityGroupEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultSecurityGroupResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDefaultSecurityGroupResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultSecurityGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2NatGateway(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no nat gateways", - dirName: "aws_ec2_nat_gateway_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNatGateways").Return([]*ec2.NatGateway{}, nil) - }, - }, - { - test: "single nat gateway", - dirName: "aws_ec2_nat_gateway_single", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNatGateways").Return([]*ec2.NatGateway{ - {NatGatewayId: awssdk.String("nat-0a5408508b19ef490")}, - }, nil) - }, - }, - { - test: "cannot list nat gateways", - dirName: "aws_ec2_nat_gateway_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllNatGateways").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsNatGatewayResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsNatGatewayResourceType, resourceaws.AwsNatGatewayResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2NatGatewayEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsNatGatewayResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsNatGatewayResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsNatGatewayResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2NetworkACL(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no network ACL", - dirName: "aws_ec2_network_acl_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{}, nil) - }, - }, - { - test: "network acl", - dirName: "aws_ec2_network_acl", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{ - { - NetworkAclId: awssdk.String("acl-043880b4682d2366b"), - IsDefault: awssdk.Bool(false), - }, - { - NetworkAclId: awssdk.String("acl-07a565dbe518c0713"), - IsDefault: awssdk.Bool(false), - }, - { - NetworkAclId: awssdk.String("acl-e88ee595"), - IsDefault: awssdk.Bool(true), - }, - }, nil) - }, - }, - { - test: "cannot list network acl", - dirName: "aws_ec2_network_acl_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllNetworkACLs").Return(nil, awsError) - - alerter.On("SendAlert", - resourceaws.AwsNetworkACLResourceType, - alerts.NewRemoteAccessDeniedAlert( - common.RemoteAWSTerraform, - remoteerr.NewResourceListingErrorWithType( - awsError, - resourceaws.AwsNetworkACLResourceType, - resourceaws.AwsNetworkACLResourceType, - ), - alerts.EnumerationPhase, - ), - ).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2NetworkACLEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsNetworkACLResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsNetworkACLResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsNetworkACLResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2NetworkACLRule(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no network ACL", - dirName: "aws_ec2_network_acl_rule_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{}, nil) - }, - }, - { - test: "network acl rules", - dirName: "aws_ec2_network_acl_rule", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{ - { - NetworkAclId: awssdk.String("acl-0ad6d657494d17ee2"), // test - IsDefault: awssdk.Bool(false), - Entries: []*ec2.NetworkAclEntry{ - { - Egress: awssdk.Bool(false), - RuleNumber: awssdk.Int64(100), - Protocol: awssdk.String("6"), // tcp - RuleAction: awssdk.String("deny"), - CidrBlock: awssdk.String("0.0.0.0/0"), - }, - { - Egress: awssdk.Bool(false), - RuleNumber: awssdk.Int64(200), - Protocol: awssdk.String("6"), // tcp - RuleAction: awssdk.String("allow"), - Ipv6CidrBlock: awssdk.String("::/0"), - }, - { - Egress: awssdk.Bool(true), - RuleNumber: awssdk.Int64(100), - Protocol: awssdk.String("17"), // udp - RuleAction: awssdk.String("allow"), - CidrBlock: awssdk.String("172.16.1.0/0"), - }, - }, - }, - { - NetworkAclId: awssdk.String("acl-0de54ef59074b622e"), // test2 - IsDefault: awssdk.Bool(false), - Entries: []*ec2.NetworkAclEntry{ - { - Egress: awssdk.Bool(false), - RuleNumber: awssdk.Int64(100), - Protocol: awssdk.String("17"), // udp - RuleAction: awssdk.String("deny"), - CidrBlock: awssdk.String("0.0.0.0/0"), - }, - { - Egress: awssdk.Bool(true), - RuleNumber: awssdk.Int64(100), - Protocol: awssdk.String("17"), // udp - RuleAction: awssdk.String("allow"), - CidrBlock: awssdk.String("172.16.1.0/0"), - }, - }, - }, - }, nil) - }, - }, - { - test: "cannot list network acl", - dirName: "aws_ec2_network_acl_rule_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllNetworkACLs").Return(nil, awsError) - - alerter.On("SendAlert", - resourceaws.AwsNetworkACLRuleResourceType, - alerts.NewRemoteAccessDeniedAlert( - common.RemoteAWSTerraform, - remoteerr.NewResourceListingErrorWithType( - awsError, - resourceaws.AwsNetworkACLRuleResourceType, - resourceaws.AwsNetworkACLResourceType, - ), - alerts.EnumerationPhase, - ), - ).Return() - }, - wantErr: nil, - }, - } - - version := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", version) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := version - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - } - - remoteLibrary.AddEnumerator(aws.NewEC2NetworkACLRuleEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsNetworkACLRuleResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsNetworkACLRuleResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsNetworkACLRuleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2DefaultNetworkACL(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no network ACL", - dirName: "aws_ec2_default_network_acl_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{}, nil) - }, - }, - { - test: "default network acl", - dirName: "aws_ec2_default_network_acl", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllNetworkACLs").Return([]*ec2.NetworkAcl{ - { - NetworkAclId: awssdk.String("acl-043880b4682d2366b"), - IsDefault: awssdk.Bool(false), - }, - { - NetworkAclId: awssdk.String("acl-07a565dbe518c0713"), - IsDefault: awssdk.Bool(false), - }, - { - NetworkAclId: awssdk.String("acl-e88ee595"), - IsDefault: awssdk.Bool(true), - }, - }, nil) - }, - }, - { - test: "cannot list default network acl", - dirName: "aws_ec2_default_network_acl_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllNetworkACLs").Return(nil, awsError) - - alerter.On("SendAlert", - resourceaws.AwsDefaultNetworkACLResourceType, - alerts.NewRemoteAccessDeniedAlert( - common.RemoteAWSTerraform, - remoteerr.NewResourceListingErrorWithType( - awsError, - resourceaws.AwsDefaultNetworkACLResourceType, - resourceaws.AwsDefaultNetworkACLResourceType, - ), - alerts.EnumerationPhase, - ), - ).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2DefaultNetworkACLEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDefaultNetworkACLResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDefaultNetworkACLResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDefaultNetworkACLResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2Route(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - // route table with no routes case is not possible - // as a default route will always be present in each route table - test: "no routes", - dirName: "aws_ec2_route_empty", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{}, nil) - }, - }, - { - test: "multiple routes (mixed default_route_table and route_table)", - dirName: "aws_ec2_route_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*ec2.RouteTable{ - { - RouteTableId: awssdk.String("rtb-096bdfb69309c54c3"), // table1 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: awssdk.String("10.0.0.0/16"), - Origin: awssdk.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: awssdk.String("1.1.1.1/32"), - GatewayId: awssdk.String("igw-030e74f73bd67f21b"), - Origin: awssdk.String("CreateRoute"), - }, - { - DestinationIpv6CidrBlock: awssdk.String("::/0"), - GatewayId: awssdk.String("igw-030e74f73bd67f21b"), - Origin: awssdk.String("CreateRoute"), - }, - }, - }, - { - RouteTableId: awssdk.String("rtb-0169b0937fd963ddc"), // table2 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: awssdk.String("10.0.0.0/16"), - Origin: awssdk.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: awssdk.String("0.0.0.0/0"), - GatewayId: awssdk.String("igw-030e74f73bd67f21b"), - Origin: awssdk.String("CreateRoute"), - }, - { - DestinationIpv6CidrBlock: awssdk.String("::/0"), - GatewayId: awssdk.String("igw-030e74f73bd67f21b"), - Origin: awssdk.String("CreateRoute"), - }, - }, - }, - { - RouteTableId: awssdk.String("rtb-02780c485f0be93c5"), // default_table - VpcId: awssdk.String("vpc-09fe5abc2309ba49d"), - Associations: []*ec2.RouteTableAssociation{ - { - Main: awssdk.Bool(true), - }, - }, - Routes: []*ec2.Route{ - { - DestinationCidrBlock: awssdk.String("10.0.0.0/16"), - Origin: awssdk.String("CreateRouteTable"), // default route - }, - { - DestinationCidrBlock: awssdk.String("10.1.1.0/24"), - GatewayId: awssdk.String("igw-030e74f73bd67f21b"), - Origin: awssdk.String("CreateRoute"), - }, - { - DestinationCidrBlock: awssdk.String("10.1.2.0/24"), - GatewayId: awssdk.String("igw-030e74f73bd67f21b"), - Origin: awssdk.String("CreateRoute"), - }, - }, - }, - { - RouteTableId: awssdk.String(""), // table3 - Routes: []*ec2.Route{ - { - DestinationCidrBlock: awssdk.String("10.0.0.0/16"), - Origin: awssdk.String("CreateRouteTable"), // default route - }, - }, - }, - }, nil) - }, - }, - { - test: "cannot list routes", - dirName: "aws_ec2_route_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllRouteTables").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsRouteResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRouteResourceType, resourceaws.AwsRouteTableResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2RouteEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRouteResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRouteResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRouteResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestVpcSecurityGroupRule(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no security group rules", - dirName: "aws_vpc_security_group_rule_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ - { - GroupId: awssdk.String("sg-0254c038e32f25530"), - IpPermissions: []*ec2.IpPermission{}, - IpPermissionsEgress: []*ec2.IpPermission{}, - }, - }, nil, nil) - }, - wantErr: nil, - }, - { - test: "with security group rules", - dirName: "aws_vpc_security_group_rule_multiple", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllSecurityGroups").Once().Return([]*ec2.SecurityGroup{ - { - GroupId: awssdk.String("sg-0254c038e32f25530"), - IpPermissions: []*ec2.IpPermission{ - { - FromPort: awssdk.Int64(0), - ToPort: awssdk.Int64(65535), - IpProtocol: awssdk.String("tcp"), - UserIdGroupPairs: []*ec2.UserIdGroupPair{ - { - GroupId: awssdk.String("sg-0254c038e32f25530"), - }, - { - GroupId: awssdk.String("sg-9e0204ff"), - }, - }, - }, - { - IpProtocol: awssdk.String("-1"), - IpRanges: []*ec2.IpRange{ - { - CidrIp: awssdk.String("1.2.0.0/16"), - }, - { - CidrIp: awssdk.String("5.6.7.0/24"), - }, - }, - Ipv6Ranges: []*ec2.Ipv6Range{ - { - CidrIpv6: awssdk.String("::/0"), - }, - }, - }, - }, - IpPermissionsEgress: []*ec2.IpPermission{ - { - IpProtocol: awssdk.String("-1"), - IpRanges: []*ec2.IpRange{ - { - CidrIp: awssdk.String("0.0.0.0/0"), - }, - }, - Ipv6Ranges: []*ec2.Ipv6Range{ - { - CidrIpv6: awssdk.String("::/0"), - }, - }, - }, - }, - }, - { - GroupId: awssdk.String("sg-0cc8b3c3c2851705a"), - IpPermissions: []*ec2.IpPermission{ - { - FromPort: awssdk.Int64(443), - ToPort: awssdk.Int64(443), - IpProtocol: awssdk.String("tcp"), - IpRanges: []*ec2.IpRange{ - { - CidrIp: awssdk.String("0.0.0.0/0"), - }, - }, - }, - }, - IpPermissionsEgress: []*ec2.IpPermission{ - { - IpProtocol: awssdk.String("-1"), - IpRanges: []*ec2.IpRange{ - { - CidrIp: awssdk.String("0.0.0.0/0"), - }, - }, - Ipv6Ranges: []*ec2.Ipv6Range{ - { - CidrIpv6: awssdk.String("::/0"), - }, - }, - }, - { - IpProtocol: awssdk.String("5"), - IpRanges: []*ec2.IpRange{ - { - CidrIp: awssdk.String("0.0.0.0/0"), - }, - }, - }, - }, - }, - }, nil, nil) - }, - wantErr: nil, - }, - { - test: "cannot list security group rules", - dirName: "aws_vpc_security_group_rule_empty", - mocks: func(client *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllSecurityGroups").Once().Return(nil, nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSecurityGroupRuleResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSecurityGroupRuleResourceType, resourceaws.AwsSecurityGroupResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewVPCSecurityGroupRuleEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSecurityGroupRuleResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSecurityGroupRuleResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSecurityGroupRuleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestEC2LaunchTemplate(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no launch template", - dirName: "aws_launch_template", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("DescribeLaunchTemplates").Return([]*ec2.LaunchTemplate{}, nil) - }, - }, - { - test: "multiple launch templates", - dirName: "aws_launch_template_multiple", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - launchTemplates := []*ec2.LaunchTemplate{ - {LaunchTemplateId: awssdk.String("lt-0ed993d09ce6afc67"), LatestVersionNumber: awssdk.Int64(1)}, - {LaunchTemplateId: awssdk.String("lt-00b2d18c6cee7fe23"), LatestVersionNumber: awssdk.Int64(1)}, - } - - repository.On("DescribeLaunchTemplates").Return(launchTemplates, nil) - }, - }, - { - test: "cannot list launch templates", - dirName: "aws_launch_template", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("DescribeLaunchTemplates").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsLaunchTemplateResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLaunchTemplateResourceType, resourceaws.AwsLaunchTemplateResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewLaunchTemplateEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsLaunchTemplateResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsLaunchTemplateResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsLaunchTemplateResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestEC2EbsEncryptionByDefault(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockEC2Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no encryption by default resource", - dirName: "aws_ebs_encryption_by_default_list", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - repository.On("IsEbsEncryptionEnabledByDefault").Return(false, nil) - }, - }, - { - test: "cannot list encryption by default resources", - dirName: "aws_ebs_encryption_by_default_error", - mocks: func(repository *repository.MockEC2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("IsEbsEncryptionEnabledByDefault").Return(false, awsError) - - alerter.On("SendAlert", resourceaws.AwsEbsEncryptionByDefaultResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEbsEncryptionByDefaultResourceType, resourceaws.AwsEbsEncryptionByDefaultResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockEC2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.EC2Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewEC2Repository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewEC2EbsEncryptionByDefaultEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsEbsEncryptionByDefaultResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsEbsEncryptionByDefaultResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsEbsEncryptionByDefaultResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_ecr_scanner_test.go b/pkg/remote/aws_ecr_scanner_test.go deleted file mode 100644 index d7a97d49..00000000 --- a/pkg/remote/aws_ecr_scanner_test.go +++ /dev/null @@ -1,197 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ecr" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestECRRepository(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockECRRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no repository", - dirName: "aws_ecr_repository_empty", - mocks: func(client *repository.MockECRRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllRepositories").Return([]*ecr.Repository{}, nil) - }, - err: nil, - }, - { - test: "multiple repositories", - dirName: "aws_ecr_repository_multiple", - mocks: func(client *repository.MockECRRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllRepositories").Return([]*ecr.Repository{ - {RepositoryName: awssdk.String("test_ecr")}, - {RepositoryName: awssdk.String("bar")}, - }, nil) - }, - err: nil, - }, - { - test: "cannot list repository", - dirName: "aws_ecr_repository_empty", - mocks: func(client *repository.MockECRRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllRepositories").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsEcrRepositoryResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsEcrRepositoryResourceType, resourceaws.AwsEcrRepositoryResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockECRRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ECRRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewECRRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewECRRepositoryEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsEcrRepositoryResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsEcrRepositoryResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsEcrRepositoryResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestECRRepositoryPolicy(t *testing.T) { - tests := []struct { - test string - mocks func(*repository.MockECRRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - err error - }{ - { - test: "single repository policy", - mocks: func(client *repository.MockECRRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllRepositories").Return([]*ecr.Repository{ - {RepositoryName: awssdk.String("test_ecr_repo_policy")}, - {RepositoryName: awssdk.String("test_ecr_repo_without_policy")}, - }, nil) - client.On("GetRepositoryPolicy", &ecr.Repository{ - RepositoryName: awssdk.String("test_ecr_repo_policy"), - }).Return(&ecr.GetRepositoryPolicyOutput{ - RegistryId: awssdk.String("1"), - RepositoryName: awssdk.String("test_ecr_repo_policy"), - }, nil) - client.On("GetRepositoryPolicy", &ecr.Repository{ - RepositoryName: awssdk.String("test_ecr_repo_without_policy"), - }).Return(nil, &ecr.RepositoryPolicyNotFoundException{}) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - - assert.Equal(t, got[0].ResourceId(), "test_ecr_repo_policy") - }, - err: nil, - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockECRRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ECRRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewECRRepositoryPolicyEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_elasticache_scanner_test.go b/pkg/remote/aws_elasticache_scanner_test.go deleted file mode 100644 index 59a72807..00000000 --- a/pkg/remote/aws_elasticache_scanner_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/elasticache" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestElastiCacheCluster(t *testing.T) { - dummyError := errors.New("dummy error") - - tests := []struct { - test string - mocks func(*repository.MockElastiCacheRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no elasticache clusters", - mocks: func(repository *repository.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllCacheClusters").Return([]*elasticache.CacheCluster{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "should list elasticache clusters", - mocks: func(repository *repository.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllCacheClusters").Return([]*elasticache.CacheCluster{ - {CacheClusterId: awssdk.String("cluster-foo")}, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, got[0].ResourceId(), "cluster-foo") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsElastiCacheClusterResourceType) - }, - }, - { - test: "cannot list elasticache clusters (403)", - mocks: func(repository *repository.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllCacheClusters").Return(nil, awsError) - alerter.On("SendAlert", resourceaws.AwsElastiCacheClusterResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsElastiCacheClusterResourceType, resourceaws.AwsElastiCacheClusterResourceType), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "cannot list elasticache clusters (dummy error)", - mocks: func(repository *repository.MockElastiCacheRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllCacheClusters").Return(nil, dummyError) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsElastiCacheClusterResourceType, ""), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockElastiCacheRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ElastiCacheRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewElastiCacheClusterEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_elb_scanner_test.go b/pkg/remote/aws_elb_scanner_test.go deleted file mode 100644 index 7104778e..00000000 --- a/pkg/remote/aws_elb_scanner_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/elb" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestELB_LoadBalancer(t *testing.T) { - dummyError := errors.New("dummy error") - - tests := []struct { - test string - mocks func(*repository.MockELBRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no load balancer", - mocks: func(repository *repository.MockELBRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elb.LoadBalancerDescription{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "should list load balancers", - mocks: func(repository *repository.MockELBRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elb.LoadBalancerDescription{ - { - LoadBalancerName: awssdk.String("acc-test-lb-tf"), - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "acc-test-lb-tf", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsClassicLoadBalancerResourceType, got[0].ResourceType()) - }, - }, - { - test: "cannot list load balancers", - mocks: func(repository *repository.MockELBRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllLoadBalancers").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsClassicLoadBalancerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsClassicLoadBalancerResourceType, resourceaws.AwsClassicLoadBalancerResourceType), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "cannot list load balancers (dummy error)", - mocks: func(repository *repository.MockELBRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return(nil, dummyError) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsClassicLoadBalancerResourceType, ""), - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockELBRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ELBRepository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewClassicLoadBalancerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_elbv2_scanner_test.go b/pkg/remote/aws_elbv2_scanner_test.go deleted file mode 100644 index 61371a5e..00000000 --- a/pkg/remote/aws_elbv2_scanner_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/elbv2" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestELBV2_LoadBalancer(t *testing.T) { - dummyError := errors.New("dummy error") - - tests := []struct { - test string - mocks func(*repository.MockELBV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no load balancer", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "should list load balancers", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ - { - LoadBalancerArn: awssdk.String("arn:aws:elasticloadbalancing:us-east-1:533948124879:loadbalancer/app/acc-test-lb-tf/9114c60e08560420"), - LoadBalancerName: awssdk.String("acc-test-lb-tf"), - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "arn:aws:elasticloadbalancing:us-east-1:533948124879:loadbalancer/app/acc-test-lb-tf/9114c60e08560420", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsLoadBalancerResourceType, got[0].ResourceType()) - }, - }, - { - test: "cannot list load balancers (403)", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllLoadBalancers").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsLoadBalancerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLoadBalancerResourceType, resourceaws.AwsLoadBalancerResourceType), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "cannot list load balancers (dummy error)", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return(nil, dummyError) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsLoadBalancerResourceType, ""), - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockELBV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ELBV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewLoadBalancerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestELBV2_LoadBalancerListener(t *testing.T) { - dummyError := errors.New("dummy error") - - tests := []struct { - test string - mocks func(*repository.MockELBV2Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no load balancer listener", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ - { - LoadBalancerArn: awssdk.String("test-lb"), - }, - }, nil) - repository.On("ListAllLoadBalancerListeners", "test-lb").Return([]*elbv2.Listener{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "should list load balancer listener", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ - { - LoadBalancerArn: awssdk.String("test-lb"), - }, - }, nil) - - repository.On("ListAllLoadBalancerListeners", "test-lb").Return([]*elbv2.Listener{ - { - ListenerArn: awssdk.String("test-lb-listener-1"), - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "test-lb-listener-1", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsLoadBalancerListenerResourceType, got[0].ResourceType()) - }, - }, - { - test: "cannot list load balancer listeners (403)", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ - { - LoadBalancerArn: awssdk.String("test-lb"), - }, - }, nil) - - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllLoadBalancerListeners", "test-lb").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsLoadBalancerListenerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingError(awsError, resourceaws.AwsLoadBalancerListenerResourceType), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "cannot list load balancers (403)", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllLoadBalancers").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsLoadBalancerListenerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLoadBalancerListenerResourceType, resourceaws.AwsLoadBalancerResourceType), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "cannot list load balancer listeners (dummy error)", - mocks: func(repository *repository.MockELBV2Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*elbv2.LoadBalancer{ - { - LoadBalancerArn: awssdk.String("test-lb"), - }, - }, nil) - - repository.On("ListAllLoadBalancerListeners", "test-lb").Return(nil, dummyError) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - wantErr: remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsLoadBalancerListenerResourceType, ""), - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockELBV2Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ELBV2Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewLoadBalancerListenerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_iam_scanner_test.go b/pkg/remote/aws_iam_scanner_test.go deleted file mode 100644 index e84ee457..00000000 --- a/pkg/remote/aws_iam_scanner_test.go +++ /dev/null @@ -1,1322 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/iam" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - remoteaws "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestIamUser(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam user", - dirName: "aws_iam_user_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllUsers").Return([]*iam.User{}, nil) - }, - wantErr: nil, - }, - { - test: "iam multiples users", - dirName: "aws_iam_user_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllUsers").Return([]*iam.User{ - { - UserName: aws.String("test-driftctl-0"), - }, - { - UserName: aws.String("test-driftctl-1"), - }, - { - UserName: aws.String("test-driftctl-2"), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list iam user", - dirName: "aws_iam_user_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllUsers").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamUserResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamUserEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamUserResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamUserResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamUserResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamUserPolicy(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam user policy", - dirName: "aws_iam_user_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - users := []*iam.User{ - { - UserName: aws.String("loadbalancer"), - }, - } - repo.On("ListAllUsers").Return(users, nil) - repo.On("ListAllUserPolicies", users).Return([]string{}, nil) - }, - wantErr: nil, - }, - { - test: "iam multiples users multiple policies", - dirName: "aws_iam_user_policy_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - users := []*iam.User{ - { - UserName: aws.String("loadbalancer"), - }, - { - UserName: aws.String("loadbalancer2"), - }, - { - UserName: aws.String("loadbalancer3"), - }, - } - repo.On("ListAllUsers").Return(users, nil) - repo.On("ListAllUserPolicies", users).Once().Return([]string{ - *aws.String("loadbalancer:test"), - *aws.String("loadbalancer:test2"), - *aws.String("loadbalancer:test3"), - *aws.String("loadbalancer:test4"), - *aws.String("loadbalancer2:test2"), - *aws.String("loadbalancer2:test22"), - *aws.String("loadbalancer2:test23"), - *aws.String("loadbalancer2:test24"), - *aws.String("loadbalancer3:test3"), - *aws.String("loadbalancer3:test32"), - *aws.String("loadbalancer3:test33"), - *aws.String("loadbalancer3:test34"), - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list user", - dirName: "aws_iam_user_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllUsers").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamUserPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "cannot list user policy", - dirName: "aws_iam_user_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllUsers").Once().Return([]*iam.User{}, nil) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllUserPolicies", mock.Anything).Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamUserPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyResourceType, resourceaws.AwsIamUserPolicyResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamUserPolicyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamUserPolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamUserPolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamUserPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamPolicy(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam custom policies", - dirName: "aws_iam_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllPolicies").Once().Return([]*iam.Policy{}, nil) - }, - wantErr: nil, - }, - { - test: "iam multiples custom policies", - dirName: "aws_iam_policy_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllPolicies").Once().Return([]*iam.Policy{ - { - Arn: aws.String("arn:aws:iam::929327065333:policy/policy-0"), - }, - { - Arn: aws.String("arn:aws:iam::929327065333:policy/policy-1"), - }, - { - Arn: aws.String("arn:aws:iam::929327065333:policy/policy-2"), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list iam custom policies", - dirName: "aws_iam_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllPolicies").Once().Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamPolicyResourceType, resourceaws.AwsIamPolicyResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamPolicyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamPolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamPolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamRole(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam roles", - dirName: "aws_iam_role_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRoles").Return([]*iam.Role{}, nil) - }, - wantErr: nil, - }, - { - test: "iam multiples roles", - dirName: "aws_iam_role_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRoles").Return([]*iam.Role{ - { - RoleName: aws.String("test_role_0"), - Path: aws.String("/"), - }, - { - RoleName: aws.String("test_role_1"), - Path: aws.String("/"), - }, - { - RoleName: aws.String("test_role_2"), - Path: aws.String("/"), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "iam roles ignore services roles", - dirName: "aws_iam_role_ignore_services_roles", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRoles").Return([]*iam.Role{ - { - RoleName: aws.String("AWSServiceRoleForOrganizations"), - Path: aws.String("/aws-service-role/organizations.amazonaws.com/"), - }, - { - RoleName: aws.String("AWSServiceRoleForSupport"), - Path: aws.String("/aws-service-role/support.amazonaws.com/"), - }, - { - RoleName: aws.String("AWSServiceRoleForTrustedAdvisor"), - Path: aws.String("/aws-service-role/trustedadvisor.amazonaws.com/"), - }, - }, nil) - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamRoleEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamRoleResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamRoleResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamRoleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamRolePolicyAttachment(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no iam role policy", - dirName: "aws_aws_iam_role_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - roles := []*iam.Role{ - { - RoleName: aws.String("test-role"), - }, - } - repo.On("ListAllRoles").Return(roles, nil) - repo.On("ListAllRolePolicyAttachments", roles).Return([]*repository.AttachedRolePolicy{}, nil) - }, - err: nil, - }, - { - test: "iam multiples roles multiple policies", - dirName: "aws_iam_role_policy_attachment_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - roles := []*iam.Role{ - { - RoleName: aws.String("test-role"), - }, - { - RoleName: aws.String("test-role2"), - }, - } - repo.On("ListAllRoles").Return(roles, nil) - repo.On("ListAllRolePolicyAttachments", roles).Return([]*repository.AttachedRolePolicy{ - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy"), - PolicyName: aws.String("test-policy"), - }, - RoleName: *aws.String("test-role"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy2"), - PolicyName: aws.String("test-policy2"), - }, - RoleName: *aws.String("test-role"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy3"), - PolicyName: aws.String("test-policy3"), - }, - RoleName: *aws.String("test-role"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy"), - PolicyName: aws.String("test-policy"), - }, - RoleName: *aws.String("test-role2"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy2"), - PolicyName: aws.String("test-policy2"), - }, - RoleName: *aws.String("test-role2"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::929327065333:policy/test-policy3"), - PolicyName: aws.String("test-policy3"), - }, - RoleName: *aws.String("test-role2"), - }, - }, nil) - }, - err: nil, - }, - { - test: "iam multiples roles for ignored roles", - dirName: "aws_iam_role_policy_attachment_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - roles := []*iam.Role{ - { - RoleName: aws.String("AWSServiceRoleForSupport"), - }, - { - RoleName: aws.String("AWSServiceRoleForOrganizations"), - }, - { - RoleName: aws.String("AWSServiceRoleForTrustedAdvisor"), - }, - } - repo.On("ListAllRoles").Return(roles, nil) - }, - }, - { - test: "Cannot list roles", - dirName: "aws_iam_role_policy_attachment_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllRoles").Once().Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamRolePolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyAttachmentResourceType, resourceaws.AwsIamRoleResourceType), alerts.EnumerationPhase)).Return() - }, - }, - { - test: "Cannot list roles policy attachment", - dirName: "aws_iam_role_policy_attachment_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRoles").Once().Return([]*iam.Role{{RoleName: aws.String("test")}}, nil) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllRolePolicyAttachments", mock.Anything).Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamRolePolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyAttachmentResourceType, resourceaws.AwsIamRolePolicyAttachmentResourceType), alerts.EnumerationPhase)).Return() - }, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamRolePolicyAttachmentEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamRolePolicyAttachmentResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamRolePolicyAttachmentResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamRolePolicyAttachmentResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamAccessKey(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam access_key", - dirName: "aws_iam_access_key_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - users := []*iam.User{ - { - UserName: aws.String("test-driftctl"), - }, - } - repo.On("ListAllUsers").Return(users, nil) - repo.On("ListAllAccessKeys", users).Return([]*iam.AccessKeyMetadata{}, nil) - }, - wantErr: nil, - }, - { - test: "iam multiples keys for multiples users", - dirName: "aws_iam_access_key_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - users := []*iam.User{ - { - UserName: aws.String("test-driftctl"), - }, - } - repo.On("ListAllUsers").Return(users, nil) - repo.On("ListAllAccessKeys", users).Return([]*iam.AccessKeyMetadata{ - { - AccessKeyId: aws.String("AKIA5QYBVVD223VWU32A"), - UserName: aws.String("test-driftctl"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD2QYI36UZP"), - UserName: aws.String("test-driftctl"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD26EJME25D"), - UserName: aws.String("test-driftctl2"), - }, - { - AccessKeyId: aws.String("AKIA5QYBVVD2SWDFVVMG"), - UserName: aws.String("test-driftctl2"), - }, - }, nil) - }, - wantErr: nil, - }, - { - test: "Cannot list iam user", - dirName: "aws_iam_access_key_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllUsers").Once().Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamAccessKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamAccessKeyResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "Cannot list iam access_key", - dirName: "aws_iam_access_key_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllUsers").Once().Return([]*iam.User{}, nil) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllAccessKeys", mock.Anything).Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamAccessKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamAccessKeyResourceType, resourceaws.AwsIamAccessKeyResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamAccessKeyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamAccessKeyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamAccessKeyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamAccessKeyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamUserPolicyAttachment(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam user policy", - dirName: "aws_iam_user_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - users := []*iam.User{ - { - UserName: aws.String("loadbalancer"), - }, - } - repo.On("ListAllUsers").Return(users, nil) - repo.On("ListAllUserPolicyAttachments", users).Return([]*repository.AttachedUserPolicy{}, nil) - }, - wantErr: nil, - }, - { - test: "iam multiples users multiple policies", - dirName: "aws_iam_user_policy_attachment_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - users := []*iam.User{ - { - UserName: aws.String("loadbalancer"), - }, - { - UserName: aws.String("loadbalancer2"), - }, - { - UserName: aws.String("loadbalancer3"), - }, - } - repo.On("ListAllUsers").Return(users, nil) - repo.On("ListAllUserPolicyAttachments", users).Return([]*repository.AttachedUserPolicy{ - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test"), - PolicyName: aws.String("test"), - }, - UserName: *aws.String("loadbalancer"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test2"), - PolicyName: aws.String("test2"), - }, - UserName: *aws.String("loadbalancer"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test3"), - PolicyName: aws.String("test3"), - }, - UserName: *aws.String("loadbalancer"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test4"), - PolicyName: aws.String("test4"), - }, - UserName: *aws.String("loadbalancer"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test"), - PolicyName: aws.String("test"), - }, - UserName: *aws.String("loadbalancer2"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test2"), - PolicyName: aws.String("test2"), - }, - UserName: *aws.String("loadbalancer2"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test3"), - PolicyName: aws.String("test3"), - }, - UserName: *aws.String("loadbalancer2"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test4"), - PolicyName: aws.String("test4"), - }, - UserName: *aws.String("loadbalancer2"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test"), - PolicyName: aws.String("test"), - }, - UserName: *aws.String("loadbalancer3"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test2"), - PolicyName: aws.String("test2"), - }, - UserName: *aws.String("loadbalancer3"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test3"), - PolicyName: aws.String("test3"), - }, - UserName: *aws.String("loadbalancer3"), - }, - { - AttachedPolicy: iam.AttachedPolicy{ - PolicyArn: aws.String("arn:aws:iam::726421854799:policy/test4"), - PolicyName: aws.String("test4"), - }, - UserName: *aws.String("loadbalancer3"), - }, - }, nil) - - }, - wantErr: nil, - }, - { - test: "cannot list user", - dirName: "aws_iam_user_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllUsers").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamUserPolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyAttachmentResourceType, resourceaws.AwsIamUserResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "cannot list user policies attachment", - dirName: "aws_iam_user_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllUsers").Once().Return([]*iam.User{}, nil) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllUserPolicyAttachments", mock.Anything).Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamUserPolicyAttachmentResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamUserPolicyAttachmentResourceType, resourceaws.AwsIamUserPolicyAttachmentResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamUserPolicyAttachmentEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamUserPolicyAttachmentResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamUserPolicyAttachmentResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamUserPolicyAttachmentResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamRolePolicy(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockIAMRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no iam role policy", - dirName: "aws_iam_role_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - roles := []*iam.Role{ - { - RoleName: aws.String("test_role"), - }, - } - repo.On("ListAllRoles").Return(roles, nil) - repo.On("ListAllRolePolicies", roles).Return([]repository.RolePolicy{}, nil) - }, - wantErr: nil, - }, - { - test: "multiples roles with inline policies", - dirName: "aws_iam_role_policy_multiple", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - roles := []*iam.Role{ - { - RoleName: aws.String("test_role_0"), - }, - { - RoleName: aws.String("test_role_1"), - }, - } - repo.On("ListAllRoles").Return(roles, nil) - repo.On("ListAllRolePolicies", roles).Return([]repository.RolePolicy{ - {Policy: "policy-role0-0", RoleName: "test_role_0"}, - {Policy: "policy-role0-1", RoleName: "test_role_0"}, - {Policy: "policy-role0-2", RoleName: "test_role_0"}, - {Policy: "policy-role1-0", RoleName: "test_role_1"}, - {Policy: "policy-role1-1", RoleName: "test_role_1"}, - {Policy: "policy-role1-2", RoleName: "test_role_1"}, - }, nil).Once() - }, - wantErr: nil, - }, - { - test: "Cannot list roles", - dirName: "aws_iam_role_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllRoles").Once().Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamRolePolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyResourceType, resourceaws.AwsIamRoleResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "cannot list role policy", - dirName: "aws_iam_role_policy_empty", - mocks: func(repo *repository.MockIAMRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllRoles").Once().Return([]*iam.Role{}, nil) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllRolePolicies", mock.Anything).Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsIamRolePolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsIamRolePolicyResourceType, resourceaws.AwsIamRolePolicyResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.IAMRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewIAMRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(remoteaws.NewIamRolePolicyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsIamRolePolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsIamRolePolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsIamRolePolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestIamGroupPolicy(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockIAMRepository) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "multiple groups, with multiples policies", - mocks: func(repository *repository.MockIAMRepository) { - repository.On("ListAllGroups").Return(nil, nil) - repository.On("ListAllGroupPolicies", []*iam.Group(nil)). - Return([]string{"group1:policy1", "group2:policy2"}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, resourceaws.AwsIamGroupPolicyResourceType, got[0].ResourceType()) - assert.Equal(t, "group1:policy1", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsIamGroupPolicyResourceType, got[1].ResourceType()) - assert.Equal(t, "group2:policy2", got[1].ResourceId()) - }, - }, - { - test: "cannot list groups", - mocks: func(repository *repository.MockIAMRepository) { - repository.On("ListAllGroups").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsIamGroupPolicyResourceType, resourceaws.AwsIamGroupResourceType), - }, - { - test: "cannot list policies", - mocks: func(repository *repository.MockIAMRepository) { - repository.On("ListAllGroups").Return(nil, nil) - repository.On("ListAllGroupPolicies", []*iam.Group(nil)).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsIamGroupPolicyResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo) - - var repo repository.IAMRepository = fakeRepo - - remoteLibrary.AddEnumerator(remoteaws.NewIamGroupPolicyEnumerator( - repo, factory, - )) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestIamGroup(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockIAMRepository) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "multiple groups, with multiples groups", - mocks: func(repository *repository.MockIAMRepository) { - repository.On("ListAllGroups").Return([]*iam.Group{ - { - GroupName: aws.String("group1"), - }, - { - GroupName: aws.String("group2"), - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, resourceaws.AwsIamGroupResourceType, got[0].ResourceType()) - assert.Equal(t, "group1", got[0].ResourceId()) - assert.Equal(t, resourceaws.AwsIamGroupResourceType, got[1].ResourceType()) - assert.Equal(t, "group2", got[1].ResourceId()) - }, - }, - { - test: "cannot list groups", - mocks: func(repository *repository.MockIAMRepository) { - repository.On("ListAllGroups").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsIamGroupResourceType), - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockIAMRepository{} - c.mocks(fakeRepo) - - var repo repository.IAMRepository = fakeRepo - - remoteLibrary.AddEnumerator(remoteaws.NewIamGroupEnumerator( - repo, factory, - )) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_kms_scanner_test.go b/pkg/remote/aws_kms_scanner_test.go deleted file mode 100644 index 90f4bdfc..00000000 --- a/pkg/remote/aws_kms_scanner_test.go +++ /dev/null @@ -1,224 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestKMSKey(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockKMSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no keys", - dirName: "aws_kms_key_empty", - mocks: func(repository *repository.MockKMSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllKeys").Return([]*kms.KeyListEntry{}, nil) - }, - }, - { - test: "multiple keys", - dirName: "aws_kms_key_multiple", - mocks: func(repository *repository.MockKMSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllKeys").Return([]*kms.KeyListEntry{ - {KeyId: awssdk.String("8ee21d91-c000-428c-8032-235aac55da36")}, - {KeyId: awssdk.String("5d765f32-bfdc-4610-b6ab-f82db5d0601b")}, - {KeyId: awssdk.String("89d2c023-ea53-40a5-b20a-d84905c622d7")}, - }, nil) - }, - }, - { - test: "cannot list keys", - dirName: "aws_kms_key_list", - mocks: func(repository *repository.MockKMSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllKeys").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsKmsKeyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsKmsKeyResourceType, resourceaws.AwsKmsKeyResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockKMSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.KMSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewKMSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewKMSKeyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsKmsKeyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsKmsKeyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsKmsKeyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestKMSAlias(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockKMSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no aliases", - dirName: "aws_kms_alias_empty", - mocks: func(repository *repository.MockKMSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllAliases").Return([]*kms.AliasListEntry{}, nil) - }, - }, - { - test: "multiple aliases", - dirName: "aws_kms_alias_multiple", - mocks: func(repository *repository.MockKMSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllAliases").Return([]*kms.AliasListEntry{ - {AliasName: awssdk.String("alias/foo")}, - {AliasName: awssdk.String("alias/bar")}, - {AliasName: awssdk.String("alias/baz20210225124429210500000001")}, - }, nil) - }, - }, - { - test: "cannot list aliases", - dirName: "aws_kms_alias_list", - mocks: func(repository *repository.MockKMSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllAliases").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsKmsAliasResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsKmsAliasResourceType, resourceaws.AwsKmsAliasResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockKMSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.KMSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewKMSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewKMSAliasEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsKmsAliasResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsKmsAliasResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsKmsAliasResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_lambda_scanner_test.go b/pkg/remote/aws_lambda_scanner_test.go deleted file mode 100644 index f29001b1..00000000 --- a/pkg/remote/aws_lambda_scanner_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/lambda" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanLambdaFunction(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockLambdaRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no lambda functions", - dirName: "aws_lambda_function_empty", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{}, nil) - }, - err: nil, - }, - { - test: "with lambda functions", - dirName: "aws_lambda_function_multiple", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{ - { - FunctionName: awssdk.String("foo"), - }, - { - FunctionName: awssdk.String("bar"), - }, - }, nil) - }, - err: nil, - }, - { - test: "One lambda with signing", - dirName: "aws_lambda_function_signed", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{ - { - FunctionName: awssdk.String("foo"), - }, - }, nil) - }, - err: nil, - }, - { - test: "cannot list lambda functions", - dirName: "aws_lambda_function_empty", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllLambdaFunctions").Return([]*lambda.FunctionConfiguration{}, awsError) - - alerter.On("SendAlert", resourceaws.AwsLambdaFunctionResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLambdaFunctionResourceType, resourceaws.AwsLambdaFunctionResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockLambdaRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.LambdaRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewLambdaRepository(session, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewLambdaFunctionEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsLambdaFunctionResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsLambdaFunctionResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsLambdaFunctionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestScanLambdaEventSourceMapping(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockLambdaRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no EventSourceMapping", - dirName: "aws_lambda_source_mapping_empty", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{}, nil) - }, - err: nil, - }, - { - test: "with 2 sqs EventSourceMapping", - dirName: "aws_lambda_source_mapping_sqs_multiple", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{ - { - UUID: awssdk.String("13ff66f8-37eb-4ad6-a0a8-594fea72df4f"), - }, - { - UUID: awssdk.String("4ad7e2b3-79e9-4713-9d9d-5af2c01d9058"), - }, - }, nil) - }, - err: nil, - }, - { - test: "with dynamo EventSourceMapping", - dirName: "aws_lambda_source_mapping_dynamo_multiple", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{ - { - UUID: awssdk.String("1aa9c4a0-060b-41c1-a9ae-dc304ebcdb00"), - }, - }, nil) - }, - err: nil, - }, - { - test: "cannot list lambda functions", - dirName: "aws_lambda_function_empty", - mocks: func(repo *repository.MockLambdaRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repo.On("ListAllLambdaEventSourceMappings").Return([]*lambda.EventSourceMappingConfiguration{}, awsError) - - alerter.On("SendAlert", resourceaws.AwsLambdaEventSourceMappingResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsLambdaEventSourceMappingResourceType, resourceaws.AwsLambdaEventSourceMappingResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockLambdaRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.LambdaRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewLambdaRepository(session, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewLambdaEventSourceMappingEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsLambdaEventSourceMappingResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsLambdaEventSourceMappingResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsLambdaEventSourceMappingResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_rds_scanner_test.go b/pkg/remote/aws_rds_scanner_test.go deleted file mode 100644 index cb0618e1..00000000 --- a/pkg/remote/aws_rds_scanner_test.go +++ /dev/null @@ -1,333 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestRDSDBInstance(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockRDSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no db instances", - dirName: "aws_rds_db_instance_empty", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBInstances").Return([]*rds.DBInstance{}, nil) - }, - }, - { - test: "single db instance", - dirName: "aws_rds_db_instance_single", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBInstances").Return([]*rds.DBInstance{ - {DBInstanceIdentifier: awssdk.String("terraform-20201015115018309600000001")}, - }, nil) - }, - }, - { - test: "multiple mixed db instances", - dirName: "aws_rds_db_instance_multiple", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBInstances").Return([]*rds.DBInstance{ - {DBInstanceIdentifier: awssdk.String("terraform-20201015115018309600000001")}, - {DBInstanceIdentifier: awssdk.String("database-1")}, - }, nil) - }, - }, - { - test: "cannot list db instances", - dirName: "aws_rds_db_instance_list", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllDBInstances").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDbInstanceResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDbInstanceResourceType, resourceaws.AwsDbInstanceResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockRDSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.RDSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewRDSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewRDSDBInstanceEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDbInstanceResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDbInstanceResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDbInstanceResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestRDSDBSubnetGroup(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockRDSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no db subnet groups", - dirName: "aws_rds_db_subnet_group_empty", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBSubnetGroups").Return([]*rds.DBSubnetGroup{}, nil) - }, - }, - { - test: "multiple db subnet groups", - dirName: "aws_rds_db_subnet_group_multiple", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBSubnetGroups").Return([]*rds.DBSubnetGroup{ - {DBSubnetGroupName: awssdk.String("foo")}, - {DBSubnetGroupName: awssdk.String("bar")}, - }, nil) - }, - }, - { - test: "cannot list db subnet groups", - dirName: "aws_rds_db_subnet_group_list", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllDBSubnetGroups").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsDbSubnetGroupResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsDbSubnetGroupResourceType, resourceaws.AwsDbSubnetGroupResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockRDSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.RDSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewRDSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewRDSDBSubnetGroupEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsDbSubnetGroupResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsDbSubnetGroupResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsDbSubnetGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestRDSCluster(t *testing.T) { - tests := []struct { - test string - dirName string - mocks func(*repository.MockRDSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no cluster", - dirName: "aws_rds_cluster_empty", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBClusters").Return([]*rds.DBCluster{}, nil) - }, - }, - { - test: "should return one result", - dirName: "aws_rds_clusters_results", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllDBClusters").Return([]*rds.DBCluster{ - { - DBClusterIdentifier: awssdk.String("aurora-cluster-demo"), - DatabaseName: awssdk.String("mydb"), - }, - { - DBClusterIdentifier: awssdk.String("aurora-cluster-demo-2"), - }, - }, nil) - }, - }, - { - test: "cannot list clusters", - dirName: "aws_rds_cluster_denied", - mocks: func(repository *repository.MockRDSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 400, "") - repository.On("ListAllDBClusters").Return(nil, awsError).Once() - - alerter.On("SendAlert", resourceaws.AwsRDSClusterResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRDSClusterResourceType, resourceaws.AwsRDSClusterResourceType), alerts.EnumerationPhase)).Return().Once() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockRDSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.RDSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewRDSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewRDSClusterEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRDSClusterResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRDSClusterResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRDSClusterResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_route53_scanner_test.go b/pkg/remote/aws_route53_scanner_test.go deleted file mode 100644 index 61dce083..00000000 --- a/pkg/remote/aws_route53_scanner_test.go +++ /dev/null @@ -1,475 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/route53" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestRoute53_HealthCheck(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockRoute53Repository, *mocks.AlerterInterface) - err error - }{ - { - test: "no health check", - dirName: "aws_route53_health_check_empty", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllHealthChecks").Return([]*route53.HealthCheck{}, nil) - }, - err: nil, - }, - { - test: "Multiple health check", - dirName: "aws_route53_health_check_multiple", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllHealthChecks").Return([]*route53.HealthCheck{ - {Id: awssdk.String("7001a9df-ded4-4802-9909-668eb80b972b")}, - {Id: awssdk.String("84fc318a-2e0d-41d6-b638-280e2f0f4e26")}, - }, nil) - }, - err: nil, - }, - { - test: "cannot list health check", - dirName: "aws_route53_health_check_empty", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllHealthChecks").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsRoute53HealthCheckResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53HealthCheckResourceType, resourceaws.AwsRoute53HealthCheckResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockRoute53Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.Route53Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewRoute53Repository(session, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewRoute53HealthCheckEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRoute53HealthCheckResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRoute53HealthCheckResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRoute53HealthCheckResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestRoute53_Zone(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockRoute53Repository, *mocks.AlerterInterface) - err error - }{ - { - test: "no zones", - dirName: "aws_route53_zone_empty", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{}, - nil, - ) - }, - err: nil, - }, - { - test: "single zone", - dirName: "aws_route53_zone_single", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{ - { - Id: awssdk.String("Z08068311RGDXPHF8KE62"), - Name: awssdk.String("foo.bar"), - }, - }, - nil, - ) - }, - err: nil, - }, - { - test: "multiples zone (test pagination)", - dirName: "aws_route53_zone_multiples", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{ - { - Id: awssdk.String("Z01809283VH9BBALZHO7B"), - Name: awssdk.String("foo-0.com"), - }, - { - Id: awssdk.String("Z01804312AV8PHE3C43AD"), - Name: awssdk.String("foo-1.com"), - }, - { - Id: awssdk.String("Z01874941AR1TCGV5K65C"), - Name: awssdk.String("foo-2.com"), - }, - }, - nil, - ) - }, - err: nil, - }, - { - test: "cannot list zones", - dirName: "aws_route53_zone_empty", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllZones").Return( - []*route53.HostedZone{}, - awsError, - ) - - alerter.On("SendAlert", resourceaws.AwsRoute53ZoneResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53ZoneResourceType, resourceaws.AwsRoute53ZoneResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockRoute53Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.Route53Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewRoute53Repository(session, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewRoute53ZoneEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRoute53ZoneResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRoute53ZoneResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRoute53ZoneResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestRoute53_Record(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockRoute53Repository, *mocks.AlerterInterface) - err error - }{ - { - test: "no records", - dirName: "aws_route53_zone_with_no_record", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{ - { - Id: awssdk.String("Z1035360GLIB82T1EH2G"), - Name: awssdk.String("foo-0.com"), - }, - }, - nil, - ) - client.On("ListRecordsForZone", "Z1035360GLIB82T1EH2G").Return([]*route53.ResourceRecordSet{}, nil) - }, - err: nil, - }, - { - test: "multiples records in multiples zones", - dirName: "aws_route53_record_multiples", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{ - { - Id: awssdk.String("Z1035360GLIB82T1EH2G"), - Name: awssdk.String("foo-0.com"), - }, - { - Id: awssdk.String("Z10347383HV75H96J919W"), - Name: awssdk.String("foo-1.com"), - }, - }, - nil, - ) - client.On("ListRecordsForZone", "Z1035360GLIB82T1EH2G").Return([]*route53.ResourceRecordSet{ - { - Name: awssdk.String("foo-0.com"), - Type: awssdk.String("NS"), - }, - { - Name: awssdk.String("test0"), - Type: awssdk.String("A"), - }, - { - Name: awssdk.String("test1"), - Type: awssdk.String("A"), - }, - { - Name: awssdk.String("test2"), - Type: awssdk.String("A"), - }, - { - Name: awssdk.String("\\052.test4."), - Type: awssdk.String("A"), - }, - }, nil) - client.On("ListRecordsForZone", "Z10347383HV75H96J919W").Return([]*route53.ResourceRecordSet{ - { - Name: awssdk.String("test2"), - Type: awssdk.String("A"), - }, - }, nil) - }, - err: nil, - }, - { - test: "explicit subdomain records", - dirName: "aws_route53_record_explicit_subdomain", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{ - { - Id: awssdk.String("Z06486383UC8WYSBZTWFM"), - Name: awssdk.String("foo-2.com"), - }, - }, - nil, - ) - client.On("ListRecordsForZone", "Z06486383UC8WYSBZTWFM").Return([]*route53.ResourceRecordSet{ - { - Name: awssdk.String("test0"), - Type: awssdk.String("TXT"), - }, - { - Name: awssdk.String("test0"), - Type: awssdk.String("A"), - }, - { - Name: awssdk.String("test1.foo-2.com"), - Type: awssdk.String("TXT"), - }, - { - Name: awssdk.String("test1.foo-2.com"), - Type: awssdk.String("A"), - }, - { - Name: awssdk.String("_test2.foo-2.com"), - Type: awssdk.String("TXT"), - }, - { - Name: awssdk.String("_test2.foo-2.com"), - Type: awssdk.String("A"), - }, - }, nil) - }, - err: nil, - }, - { - test: "cannot list zones", - dirName: "aws_route53_zone_with_no_record", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllZones").Return( - []*route53.HostedZone{}, - awsError) - - alerter.On("SendAlert", resourceaws.AwsRoute53RecordResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53RecordResourceType, resourceaws.AwsRoute53ZoneResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - { - test: "cannot list records", - dirName: "aws_route53_zone_with_no_record", - mocks: func(client *repository.MockRoute53Repository, alerter *mocks.AlerterInterface) { - client.On("ListAllZones").Return( - []*route53.HostedZone{ - { - Id: awssdk.String("Z06486383UC8WYSBZTWFM"), - Name: awssdk.String("foo-2.com"), - }, - }, - nil) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListRecordsForZone", "Z06486383UC8WYSBZTWFM").Return( - []*route53.ResourceRecordSet{}, awsError) - - alerter.On("SendAlert", resourceaws.AwsRoute53RecordResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsRoute53RecordResourceType, resourceaws.AwsRoute53RecordResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockRoute53Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.Route53Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewRoute53Repository(session, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewRoute53RecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsRoute53RecordResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsRoute53RecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsRoute53RecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_s3_scanner_test.go b/pkg/remote/aws_s3_scanner_test.go deleted file mode 100644 index 453b2e03..00000000 --- a/pkg/remote/aws_s3_scanner_test.go +++ /dev/null @@ -1,1082 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - tf "github.com/snyk/driftctl/pkg/remote/terraform" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/remote/aws/client" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestS3Bucket(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "multiple bucket", dirName: "aws_s3_bucket_multiple", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - {Name: awssdk.String("bucket-martin-test-drift2")}, - {Name: awssdk.String("bucket-martin-test-drift3")}, - }, nil) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift2", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift3", - ).Return( - "ap-northeast-1", - nil, - ) - }, - }, - { - test: "cannot list bucket", dirName: "aws_s3_bucket_list", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllBuckets").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsS3BucketResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - var repo repository.S3Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewS3BucketEnumerator(repo, factory, tf.TerraformProviderConfig{ - Name: "test", - DefaultAlias: "eu-west-3", - }, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsS3BucketResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestS3BucketInventory(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "multiple bucket with multiple inventories", dirName: "aws_s3_bucket_inventories_multiple", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - {Name: awssdk.String("bucket-martin-test-drift2")}, - {Name: awssdk.String("bucket-martin-test-drift3")}, - }, nil) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift2", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift3", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "ListBucketInventoryConfigurations", - &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift2")}, - "eu-west-3", - ).Return( - []*s3.InventoryConfiguration{ - {Id: awssdk.String("Inventory_Bucket2")}, - {Id: awssdk.String("Inventory2_Bucket2")}, - }, - nil, - ) - }, - }, - { - test: "cannot list bucket", dirName: "aws_s3_bucket_inventories_list_bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllBuckets").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsS3BucketInventoryResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketInventoryResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "cannot list bucket inventories", dirName: "aws_s3_bucket_inventories_list_inventories", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllBuckets").Return( - []*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - }, - nil, - ) - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-3", - nil, - ) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On( - "ListBucketInventoryConfigurations", - &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift")}, - "eu-west-3", - ).Return( - nil, - awsError, - ) - - alerter.On("SendAlert", resourceaws.AwsS3BucketInventoryResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketInventoryResourceType, resourceaws.AwsS3BucketInventoryResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - var repo repository.S3Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewS3BucketInventoryEnumerator(repo, factory, tf.TerraformProviderConfig{ - Name: "test", - DefaultAlias: "eu-west-3", - }, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketInventoryResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsS3BucketInventoryResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketInventoryResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestS3BucketNotification(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "single bucket without notifications", - dirName: "aws_s3_bucket_notifications_no_notif", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("dritftctl-test-no-notifications")}, - }, nil) - - repository.On( - "GetBucketLocation", - "dritftctl-test-no-notifications", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketNotification", - "dritftctl-test-no-notifications", - "eu-west-3", - ).Return( - nil, - nil, - ) - }, - }, - { - test: "multiple bucket with notifications", dirName: "aws_s3_bucket_notifications_multiple", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - {Name: awssdk.String("bucket-martin-test-drift2")}, - {Name: awssdk.String("bucket-martin-test-drift3")}, - }, nil) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift2", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketNotification", - "bucket-martin-test-drift2", - "eu-west-3", - ).Return( - &s3.NotificationConfiguration{ - LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{ - { - Id: awssdk.String("tf-s3-lambda-20201103165354926600000001"), - }, - { - Id: awssdk.String("tf-s3-lambda-20201103165354926600000002"), - }, - }, - }, - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift3", - ).Return( - "ap-northeast-1", - nil, - ) - }, - }, - { - test: "Cannot get bucket notification", dirName: "aws_s3_bucket_notifications_list_bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("dritftctl-test-notifications-error")}, - }, nil) - repository.On( - "GetBucketLocation", - "dritftctl-test-notifications-error", - ).Return( - "eu-west-3", - nil, - ) - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("GetBucketNotification", "dritftctl-test-notifications-error", "eu-west-3").Return(nil, awsError) - - alerter.On("SendAlert", "aws_s3_bucket_notification.dritftctl-test-notifications-error", alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, "aws_s3_bucket_notification.dritftctl-test-notifications-error", resourceaws.AwsS3BucketNotificationResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "Cannot list bucket", dirName: "aws_s3_bucket_notifications_list_bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllBuckets").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsS3BucketNotificationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketNotificationResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - var repo repository.S3Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewS3BucketNotificationEnumerator(repo, factory, tf.TerraformProviderConfig{ - Name: "test", - DefaultAlias: "eu-west-3", - }, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketNotificationResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsS3BucketNotificationResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketNotificationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestS3BucketMetrics(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "multiple bucket with multiple metrics", dirName: "aws_s3_bucket_metrics_multiple", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - {Name: awssdk.String("bucket-martin-test-drift2")}, - {Name: awssdk.String("bucket-martin-test-drift3")}, - }, nil) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift2", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift3", - ).Return( - "ap-northeast-1", - nil, - ) - - repository.On( - "ListBucketMetricsConfigurations", - &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift2")}, - "eu-west-3", - ).Return( - []*s3.MetricsConfiguration{ - {Id: awssdk.String("Metrics_Bucket2")}, - {Id: awssdk.String("Metrics2_Bucket2")}, - }, - nil, - ) - }, - }, - { - test: "cannot list bucket", dirName: "aws_s3_bucket_metrics_list_bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllBuckets").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsS3BucketMetricResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketMetricResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "cannot list metrics", dirName: "aws_s3_bucket_metrics_list_metrics", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllBuckets").Return( - []*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - }, - nil, - ) - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-3", - nil, - ) - - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On( - "ListBucketMetricsConfigurations", - &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift")}, - "eu-west-3", - ).Return( - nil, - awsError, - ) - - alerter.On("SendAlert", resourceaws.AwsS3BucketMetricResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketMetricResourceType, resourceaws.AwsS3BucketMetricResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - var repo repository.S3Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewS3BucketMetricsEnumerator(repo, factory, tf.TerraformProviderConfig{ - Name: "test", - DefaultAlias: "eu-west-3", - }, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketMetricResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsS3BucketMetricResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketMetricResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestS3BucketPolicy(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "single bucket without policy", - dirName: "aws_s3_bucket_policy_no_policy", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("dritftctl-test-no-policy")}, - }, nil) - - repository.On( - "GetBucketLocation", - "dritftctl-test-no-policy", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketPolicy", - "dritftctl-test-no-policy", - "eu-west-3", - ).Return( - nil, - nil, - ) - }, - }, - { - test: "multiple bucket with policies", dirName: "aws_s3_bucket_policies_multiple", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - {Name: awssdk.String("bucket-martin-test-drift2")}, - {Name: awssdk.String("bucket-martin-test-drift3")}, - }, nil) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift2", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift3", - ).Return( - "ap-northeast-1", - nil, - ) - - repository.On( - "GetBucketPolicy", - "bucket-martin-test-drift2", - "eu-west-3", - ).Return( - // The value here not matter, we only want something not empty - // to trigger the detail fetcher - awssdk.String("foobar"), - nil, - ) - - }, - }, - { - test: "cannot list bucket", dirName: "aws_s3_bucket_policies_list_bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllBuckets").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsS3BucketPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketPolicyResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - var repo repository.S3Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewS3BucketPolicyEnumerator(repo, factory, tf.TerraformProviderConfig{ - Name: "test", - DefaultAlias: "eu-west-3", - }, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketPolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsS3BucketPolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestS3BucketPublicAccessBlock(t *testing.T) { - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "multiple bucket, one with access block", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllBuckets").Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-with-public-access-block")}, - {Name: awssdk.String("bucket-without-public-access-block")}, - }, nil) - - repository.On("GetBucketLocation", "bucket-with-public-access-block"). - Return("us-east-1", nil) - repository.On("GetBucketLocation", "bucket-without-public-access-block"). - Return("us-east-1", nil) - - repository.On("GetBucketPublicAccessBlock", "bucket-with-public-access-block", "us-east-1"). - Return(&s3.PublicAccessBlockConfiguration{ - BlockPublicAcls: awssdk.Bool(true), - BlockPublicPolicy: awssdk.Bool(false), - }, nil) - - repository.On("GetBucketPublicAccessBlock", "bucket-without-public-access-block", "us-east-1"). - Return(nil, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, got[0].ResourceId(), "bucket-with-public-access-block") - assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3BucketPublicAccessBlockResourceType) - assert.Equal(t, got[0].Attributes(), &resource.Attributes{ - "block_public_acls": true, - "block_public_policy": false, - "ignore_public_acls": false, - "restrict_public_buckets": false, - }) - }, - }, - { - test: "cannot list bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllBuckets").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsS3BucketPublicAccessBlockResourceType, resourceaws.AwsS3BucketResourceType), - }, - { - test: "cannot list public access block", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllBuckets").Return([]*s3.Bucket{{Name: awssdk.String("foobar")}}, nil) - repository.On("GetBucketLocation", "foobar").Return("us-east-1", nil) - repository.On("GetBucketPublicAccessBlock", "foobar", "us-east-1").Return(nil, dummyError) - alerter.On("SendAlert", "aws_s3_bucket_public_access_block.foobar", alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceScanningError(dummyError, resourceaws.AwsS3BucketPublicAccessBlockResourceType, "foobar"), alerts.EnumerationPhase)).Return() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - } - - providerVersion := "3.19.0" - schemaRepository := testresource.InitFakeSchemaRepository("aws", providerVersion) - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - - var repo repository.S3Repository = fakeRepo - - remoteLibrary.AddEnumerator(aws.NewS3BucketPublicAccessBlockEnumerator( - repo, factory, - tf.TerraformProviderConfig{DefaultAlias: "us-east-1"}, - alerter, - )) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - }) - } -} - -func TestS3BucketAnalytic(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*repository.MockS3Repository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "multiple bucket with multiple analytics", - dirName: "aws_s3_bucket_analytics_multiple", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On( - "ListAllBuckets", - ).Return([]*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - {Name: awssdk.String("bucket-martin-test-drift2")}, - {Name: awssdk.String("bucket-martin-test-drift3")}, - }, nil) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-1", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift2", - ).Return( - "eu-west-3", - nil, - ) - - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift3", - ).Return( - "ap-northeast-1", - nil, - ) - - repository.On( - "ListBucketAnalyticsConfigurations", - &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift2")}, - "eu-west-3", - ).Return( - []*s3.AnalyticsConfiguration{ - {Id: awssdk.String("Analytics_Bucket2")}, - {Id: awssdk.String("Analytics2_Bucket2")}, - }, - nil, - ) - }, - }, - { - test: "cannot list bucket", dirName: "aws_s3_bucket_analytics_list_bucket", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On("ListAllBuckets").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, resourceaws.AwsS3BucketResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - { - test: "cannot list Analytics", dirName: "aws_s3_bucket_analytics_list_analytics", - mocks: func(repository *repository.MockS3Repository, alerter *mocks.AlerterInterface) { - repository.On("ListAllBuckets").Return( - []*s3.Bucket{ - {Name: awssdk.String("bucket-martin-test-drift")}, - }, - nil, - ) - repository.On( - "GetBucketLocation", - "bucket-martin-test-drift", - ).Return( - "eu-west-3", - nil, - ) - - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - repository.On( - "ListBucketAnalyticsConfigurations", - &s3.Bucket{Name: awssdk.String("bucket-martin-test-drift")}, - "eu-west-3", - ).Return( - nil, - awsError, - ) - - alerter.On("SendAlert", resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - session := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockS3Repository{} - c.mocks(fakeRepo, alerter) - var repo repository.S3Repository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewS3Repository(client.NewAWSClientFactory(session), cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewS3BucketAnalyticEnumerator(repo, factory, tf.TerraformProviderConfig{ - Name: "test", - DefaultAlias: "eu-west-3", - }, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsS3BucketAnalyticsConfigurationResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_sns_scanner_test.go b/pkg/remote/aws_sns_scanner_test.go deleted file mode 100644 index 781d6245..00000000 --- a/pkg/remote/aws_sns_scanner_test.go +++ /dev/null @@ -1,348 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sns" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanSNSTopic(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockSNSRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no SNS Topic", - dirName: "aws_sns_topic_empty", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllTopics").Return([]*sns.Topic{}, nil) - }, - err: nil, - }, - { - test: "Multiple SNSTopic", - dirName: "aws_sns_topic_multiple", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllTopics").Return([]*sns.Topic{ - {TopicArn: awssdk.String("arn:aws:sns:eu-west-3:526954929923:user-updates-topic")}, - {TopicArn: awssdk.String("arn:aws:sns:eu-west-3:526954929923:user-updates-topic2")}, - {TopicArn: awssdk.String("arn:aws:sns:eu-west-3:526954929923:user-updates-topic3")}, - }, nil) - }, - err: nil, - }, - { - test: "cannot list SNSTopic", - dirName: "aws_sns_topic_empty", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllTopics").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSnsTopicResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSnsTopicResourceType, resourceaws.AwsSnsTopicResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockSNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.SNSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewSNSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewSNSTopicEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSnsTopicResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSnsTopicResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSnsTopicResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestSNSTopicPolicyScan(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockSNSRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no SNS Topic policy", - dirName: "aws_sns_topic_policy_empty", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllTopics").Return([]*sns.Topic{}, nil) - }, - err: nil, - }, - { - test: "Multiple SNSTopicPolicy", - dirName: "aws_sns_topic_policy_multiple", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllTopics").Return([]*sns.Topic{ - {TopicArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:my-topic-with-policy")}, - {TopicArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:my-topic-with-policy2")}, - }, nil) - }, - err: nil, - }, - { - test: "cannot list SNSTopic", - dirName: "aws_sns_topic_policy_topic_list", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllTopics").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSnsTopicPolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSnsTopicPolicyResourceType, resourceaws.AwsSnsTopicResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockSNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.SNSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewSNSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewSNSTopicPolicyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSnsTopicPolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSnsTopicPolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.err, err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSnsTopicPolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestSNSTopicSubscriptionScan(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*repository.MockSNSRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no SNS Topic Subscription", - dirName: "aws_sns_topic_subscription_empty", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllSubscriptions").Return([]*sns.Subscription{}, nil) - }, - err: nil, - }, - { - test: "Multiple SNSTopic Subscription", - dirName: "aws_sns_topic_subscription_multiple", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllSubscriptions").Return([]*sns.Subscription{ - {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic2:c0f794c5-a009-4db4-9147-4c55959787fa")}, - {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic:b6e66147-2b31-4486-8d4b-2a2272264c8e")}, - }, nil) - }, - err: nil, - }, - { - test: "Multiple SNSTopic Subscription with one pending and one incorrect", - dirName: "aws_sns_topic_subscription_multiple", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllSubscriptions").Return([]*sns.Subscription{ - {SubscriptionArn: awssdk.String("PendingConfirmation"), Endpoint: awssdk.String("TEST")}, - {SubscriptionArn: awssdk.String("Incorrect"), Endpoint: awssdk.String("INCORRECT")}, - {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic2:c0f794c5-a009-4db4-9147-4c55959787fa")}, - {SubscriptionArn: awssdk.String("arn:aws:sns:us-east-1:526954929923:user-updates-topic:b6e66147-2b31-4486-8d4b-2a2272264c8e")}, - }, nil) - - alerter.On("SendAlert", "aws_sns_topic_subscription.PendingConfirmation", aws.NewWrongArnTopicAlert("PendingConfirmation", awssdk.String("TEST"))).Return() - - alerter.On("SendAlert", "aws_sns_topic_subscription.Incorrect", aws.NewWrongArnTopicAlert("Incorrect", awssdk.String("INCORRECT"))).Return() - }, - err: nil, - }, - { - test: "cannot list SNSTopic subscription", - dirName: "aws_sns_topic_subscription_list", - mocks: func(client *repository.MockSNSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllSubscriptions").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSnsTopicSubscriptionResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSnsTopicSubscriptionResourceType, resourceaws.AwsSnsTopicSubscriptionResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockSNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.SNSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewSNSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewSNSTopicSubscriptionEnumerator(repo, factory, alerter)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSnsTopicSubscriptionResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSnsTopicSubscriptionResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSnsTopicSubscriptionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/aws_sqs_scanner_test.go b/pkg/remote/aws_sqs_scanner_test.go deleted file mode 100644 index dd60d6c0..00000000 --- a/pkg/remote/aws_sqs_scanner_test.go +++ /dev/null @@ -1,253 +0,0 @@ -package remote - -import ( - "testing" - - awssdk "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sqs" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestSQSQueue(t *testing.T) { - cases := []struct { - test string - dirName string - mocks func(*repository.MockSQSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no sqs queues", - dirName: "aws_sqs_queue_empty", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllQueues").Return([]*string{}, nil) - }, - wantErr: nil, - }, - { - test: "multiple sqs queues", - dirName: "aws_sqs_queue_multiple", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllQueues").Return([]*string{ - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), - }, nil) - }, - wantErr: nil, - }, - { - test: "cannot list sqs queues", - dirName: "aws_sqs_queue_empty", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllQueues").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSqsQueueResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSqsQueueResourceType, resourceaws.AwsSqsQueueResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockSQSRepository{} - c.mocks(fakeRepo, alerter) - var repo repository.SQSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewSQSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewSQSQueueEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSqsQueueResourceType, aws.NewSQSQueueDetailsFetcher(provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSqsQueueResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - fakeRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} - -func TestSQSQueuePolicy(t *testing.T) { - cases := []struct { - test string - dirName string - mocks func(*repository.MockSQSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - // sqs queue with no policy case is not possible - // as a default SQSDefaultPolicy (e.g. policy="") will always be present in each queue - test: "no sqs queue policies", - dirName: "aws_sqs_queue_policy_empty", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllQueues").Return([]*string{}, nil) - }, - wantErr: nil, - }, - { - test: "multiple sqs queue policies (default or not)", - dirName: "aws_sqs_queue_policy_multiple", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllQueues").Return([]*string{ - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), - }, nil) - - client.On("GetQueueAttributes", mock.Anything).Return( - &sqs.GetQueueAttributesOutput{ - Attributes: map[string]*string{ - sqs.QueueAttributeNamePolicy: awssdk.String(""), - }, - }, - nil, - ) - }, - wantErr: nil, - }, - { - test: "multiple sqs queue policies (with nil attributes)", - dirName: "aws_sqs_queue_policy_multiple", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - client.On("ListAllQueues").Return([]*string{ - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/bar.fifo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/foo"), - awssdk.String("https://sqs.eu-west-3.amazonaws.com/047081014315/baz"), - }, nil) - - client.On("GetQueueAttributes", mock.Anything).Return( - &sqs.GetQueueAttributesOutput{}, - nil, - ) - }, - wantErr: nil, - }, - { - test: "cannot list sqs queues, thus sqs queue policies", - dirName: "aws_sqs_queue_policy_empty", - mocks: func(client *repository.MockSQSRepository, alerter *mocks.AlerterInterface) { - awsError := awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, "") - client.On("ListAllQueues").Return(nil, awsError) - - alerter.On("SendAlert", resourceaws.AwsSqsQueuePolicyResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awsError, resourceaws.AwsSqsQueuePolicyResourceType, resourceaws.AwsSqsQueueResourceType), alerts.EnumerationPhase)).Return() - }, - wantErr: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("aws", "3.19.0") - resourceaws.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - sess := session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - })) - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockSQSRepository{} - c.mocks(fakeRepo, alerter) - var repo repository.SQSRepository = fakeRepo - providerVersion := "3.19.0" - realProvider, err := terraform2.InitTestAwsProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = repository.NewSQSRepository(sess, cache.New(0)) - } - - remoteLibrary.AddEnumerator(aws.NewSQSQueuePolicyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceaws.AwsSqsQueuePolicyResourceType, common.NewGenericDetailsFetcher(resourceaws.AwsSqsQueuePolicyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceaws.AwsSqsQueuePolicyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - fakeRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm/azurerm_container_registry_enumerator.go b/pkg/remote/azurerm/azurerm_container_registry_enumerator.go deleted file mode 100644 index 6c997f20..00000000 --- a/pkg/remote/azurerm/azurerm_container_registry_enumerator.go +++ /dev/null @@ -1,45 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermContainerRegistryEnumerator struct { - repository repository.ContainerRegistryRepository - factory resource.ResourceFactory -} - -func NewAzurermContainerRegistryEnumerator(repo repository.ContainerRegistryRepository, factory resource.ResourceFactory) *AzurermContainerRegistryEnumerator { - return &AzurermContainerRegistryEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermContainerRegistryEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureContainerRegistryResourceType -} - -func (e *AzurermContainerRegistryEnumerator) Enumerate() ([]*resource.Resource, error) { - registries, err := e.repository.ListAllContainerRegistries() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - for _, registry := range registries { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *registry.ID, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_firewalls_enumerator.go b/pkg/remote/azurerm/azurerm_firewalls_enumerator.go deleted file mode 100644 index 101893d7..00000000 --- a/pkg/remote/azurerm/azurerm_firewalls_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermFirewallsEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermFirewallsEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermFirewallsEnumerator { - return &AzurermFirewallsEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermFirewallsEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureFirewallResourceType -} - -func (e *AzurermFirewallsEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllFirewalls() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_image_enumerator.go b/pkg/remote/azurerm/azurerm_image_enumerator.go deleted file mode 100644 index cad98884..00000000 --- a/pkg/remote/azurerm/azurerm_image_enumerator.go +++ /dev/null @@ -1,65 +0,0 @@ -package azurerm - -import ( - "strings" - - "github.com/Azure/go-autorest/autorest/azure" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermImageEnumerator struct { - repository repository.ComputeRepository - factory resource.ResourceFactory -} - -func NewAzurermImageEnumerator(repo repository.ComputeRepository, factory resource.ResourceFactory) *AzurermImageEnumerator { - return &AzurermImageEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermImageEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureImageResourceType -} - -func (e *AzurermImageEnumerator) Enumerate() ([]*resource.Resource, error) { - images, err := e.repository.ListAllImages() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(images)) - - for _, res := range images { - r, err := azure.ParseResourceID(*res.ID) - if err != nil { - logrus.WithFields(map[string]interface{}{ - "id": *res.ID, - "type": string(e.SupportedType()), - }).Error("Failed to parse Azure resource ID") - continue - } - - // Here we turn the resource group into lowercase because for some reason the API returns it in uppercase. - resourceId := strings.Replace(*res.ID, r.ResourceGroup, strings.ToLower(r.ResourceGroup), 1) - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - resourceId, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_lb_enumerator.go b/pkg/remote/azurerm/azurerm_lb_enumerator.go deleted file mode 100644 index 8d966aa2..00000000 --- a/pkg/remote/azurerm/azurerm_lb_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermLoadBalancerEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermLoadBalancerEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermLoadBalancerEnumerator { - return &AzurermLoadBalancerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermLoadBalancerEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureLoadBalancerResourceType -} - -func (e *AzurermLoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) { - loadBalancers, err := e.repository.ListAllLoadBalancers() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(loadBalancers)) - - for _, res := range loadBalancers { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_lb_rule_enumerator.go b/pkg/remote/azurerm/azurerm_lb_rule_enumerator.go deleted file mode 100644 index e601f10c..00000000 --- a/pkg/remote/azurerm/azurerm_lb_rule_enumerator.go +++ /dev/null @@ -1,56 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermLoadBalancerRuleEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermLoadBalancerRuleEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermLoadBalancerRuleEnumerator { - return &AzurermLoadBalancerRuleEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermLoadBalancerRuleEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureLoadBalancerRuleResourceType -} - -func (e *AzurermLoadBalancerRuleEnumerator) Enumerate() ([]*resource.Resource, error) { - loadBalancers, err := e.repository.ListAllLoadBalancers() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureLoadBalancerResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, res := range loadBalancers { - rules, err := e.repository.ListLoadBalancerRules(res) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, rule := range rules { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *rule.ID, - map[string]interface{}{ - "name": *rule.Name, - "loadbalancer_id": *res.ID, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_network_security_group_enumerator.go b/pkg/remote/azurerm/azurerm_network_security_group_enumerator.go deleted file mode 100644 index 69925f77..00000000 --- a/pkg/remote/azurerm/azurerm_network_security_group_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermNetworkSecurityGroupEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermNetworkSecurityGroupEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermNetworkSecurityGroupEnumerator { - return &AzurermNetworkSecurityGroupEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermNetworkSecurityGroupEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureNetworkSecurityGroupResourceType -} - -func (e *AzurermNetworkSecurityGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - securityGroups, err := e.repository.ListAllSecurityGroups() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureNetworkSecurityGroupResourceType) - } - - results := make([]*resource.Resource, 0, len(securityGroups)) - - for _, res := range securityGroups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_postgresql_database_enumerator.go b/pkg/remote/azurerm/azurerm_postgresql_database_enumerator.go deleted file mode 100644 index d1818e20..00000000 --- a/pkg/remote/azurerm/azurerm_postgresql_database_enumerator.go +++ /dev/null @@ -1,54 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPostgresqlDatabaseEnumerator struct { - repository repository.PostgresqlRespository - factory resource.ResourceFactory -} - -func NewAzurermPostgresqlDatabaseEnumerator(repo repository.PostgresqlRespository, factory resource.ResourceFactory) *AzurermPostgresqlDatabaseEnumerator { - return &AzurermPostgresqlDatabaseEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPostgresqlDatabaseEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePostgresqlDatabaseResourceType -} - -func (e *AzurermPostgresqlDatabaseEnumerator) Enumerate() ([]*resource.Resource, error) { - servers, err := e.repository.ListAllServers() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePostgresqlServerResourceType) - } - - results := make([]*resource.Resource, 0) - for _, server := range servers { - databases, err := e.repository.ListAllDatabasesByServer(server) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, db := range databases { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *db.ID, - map[string]interface{}{ - "name": *db.Name, - }, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_postgresql_server_enumerator.go b/pkg/remote/azurerm/azurerm_postgresql_server_enumerator.go deleted file mode 100644 index 3646e128..00000000 --- a/pkg/remote/azurerm/azurerm_postgresql_server_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPostgresqlServerEnumerator struct { - repository repository.PostgresqlRespository - factory resource.ResourceFactory -} - -func NewAzurermPostgresqlServerEnumerator(repo repository.PostgresqlRespository, factory resource.ResourceFactory) *AzurermPostgresqlServerEnumerator { - return &AzurermPostgresqlServerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPostgresqlServerEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePostgresqlServerResourceType -} - -func (e *AzurermPostgresqlServerEnumerator) Enumerate() ([]*resource.Resource, error) { - servers, err := e.repository.ListAllServers() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - for _, server := range servers { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *server.ID, - map[string]interface{}{ - "name": *server.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_private_dns_cname_record_enumerator.go b/pkg/remote/azurerm/azurerm_private_dns_cname_record_enumerator.go deleted file mode 100644 index 16ee70e7..00000000 --- a/pkg/remote/azurerm/azurerm_private_dns_cname_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSCNameRecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSCNameRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSCNameRecordEnumerator { - return &AzurermPrivateDNSCNameRecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSCNameRecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSCNameRecordResourceType -} - -func (e *AzurermPrivateDNSCNameRecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllCNAMERecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_a_record_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_a_record_enumerator.go deleted file mode 100644 index 1ef54d83..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_a_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSARecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSARecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSARecordEnumerator { - return &AzurermPrivateDNSARecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSARecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSARecordResourceType -} - -func (e *AzurermPrivateDNSARecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllARecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_aaaa_record_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_aaaa_record_enumerator.go deleted file mode 100644 index db7a04ff..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_aaaa_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSAAAARecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSAAAARecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSAAAARecordEnumerator { - return &AzurermPrivateDNSAAAARecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSAAAARecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSAAAARecordResourceType -} - -func (e *AzurermPrivateDNSAAAARecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllAAAARecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_mx_record_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_mx_record_enumerator.go deleted file mode 100644 index 915f9d8c..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_mx_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSMXRecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSMXRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSMXRecordEnumerator { - return &AzurermPrivateDNSMXRecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSMXRecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSMXRecordResourceType -} - -func (e *AzurermPrivateDNSMXRecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllMXRecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_ptr_record_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_ptr_record_enumerator.go deleted file mode 100644 index 224e8f9c..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_ptr_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSPTRRecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSPTRRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSPTRRecordEnumerator { - return &AzurermPrivateDNSPTRRecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSPTRRecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSPTRRecordResourceType -} - -func (e *AzurermPrivateDNSPTRRecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllPTRRecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_srv_record_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_srv_record_enumerator.go deleted file mode 100644 index 5855a789..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_srv_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSSRVRecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSSRVRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSSRVRecordEnumerator { - return &AzurermPrivateDNSSRVRecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSSRVRecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSSRVRecordResourceType -} - -func (e *AzurermPrivateDNSSRVRecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllSRVRecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_txt_record_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_txt_record_enumerator.go deleted file mode 100644 index dd56f6c3..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_txt_record_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSTXTRecordEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSTXTRecordEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSTXTRecordEnumerator { - return &AzurermPrivateDNSTXTRecordEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSTXTRecordEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSTXTRecordResourceType -} - -func (e *AzurermPrivateDNSTXTRecordEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzurePrivateDNSZoneResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - records, err := e.repository.ListAllTXTRecords(zone) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, record := range records { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *record.ID, - map[string]interface{}{ - "name": *record.Name, - "zone_name": *zone.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_privatedns_zone_enumerator.go b/pkg/remote/azurerm/azurerm_privatedns_zone_enumerator.go deleted file mode 100644 index a1e01097..00000000 --- a/pkg/remote/azurerm/azurerm_privatedns_zone_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPrivateDNSZoneEnumerator struct { - repository repository.PrivateDNSRepository - factory resource.ResourceFactory -} - -func NewAzurermPrivateDNSZoneEnumerator(repo repository.PrivateDNSRepository, factory resource.ResourceFactory) *AzurermPrivateDNSZoneEnumerator { - return &AzurermPrivateDNSZoneEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPrivateDNSZoneEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePrivateDNSZoneResourceType -} - -func (e *AzurermPrivateDNSZoneEnumerator) Enumerate() ([]*resource.Resource, error) { - - zones, err := e.repository.ListAllPrivateZones() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - - for _, zone := range zones { - - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *zone.ID, - map[string]interface{}{}, - ), - ) - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_public_ip_enumerator.go b/pkg/remote/azurerm/azurerm_public_ip_enumerator.go deleted file mode 100644 index 4a3c45c8..00000000 --- a/pkg/remote/azurerm/azurerm_public_ip_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermPublicIPEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermPublicIPEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermPublicIPEnumerator { - return &AzurermPublicIPEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermPublicIPEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzurePublicIPResourceType -} - -func (e *AzurermPublicIPEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllPublicIPAddresses() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_resource_group_enumerator.go b/pkg/remote/azurerm/azurerm_resource_group_enumerator.go deleted file mode 100644 index 0867cca7..00000000 --- a/pkg/remote/azurerm/azurerm_resource_group_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermResourceGroupEnumerator struct { - repository repository.ResourcesRepository - factory resource.ResourceFactory -} - -func NewAzurermResourceGroupEnumerator(repo repository.ResourcesRepository, factory resource.ResourceFactory) *AzurermResourceGroupEnumerator { - return &AzurermResourceGroupEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermResourceGroupEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureResourceGroupResourceType -} - -func (e *AzurermResourceGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - groups, err := e.repository.ListAllResourceGroups() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0) - for _, group := range groups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *group.ID, - map[string]interface{}{ - "name": *group.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_route_enumerator.go b/pkg/remote/azurerm/azurerm_route_enumerator.go deleted file mode 100644 index 6f3ce148..00000000 --- a/pkg/remote/azurerm/azurerm_route_enumerator.go +++ /dev/null @@ -1,52 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermRouteEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermRouteEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermRouteEnumerator { - return &AzurermRouteEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermRouteEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureRouteResourceType -} - -func (e *AzurermRouteEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllRouteTables() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureRouteTableResourceType) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - for _, route := range res.Properties.Routes { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *route.ID, - map[string]interface{}{ - "name": *route.Name, - "route_table_name": *res.Name, - }, - ), - ) - } - - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_route_table_enumerator.go b/pkg/remote/azurerm/azurerm_route_table_enumerator.go deleted file mode 100644 index 6731de6c..00000000 --- a/pkg/remote/azurerm/azurerm_route_table_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermRouteTableEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermRouteTableEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermRouteTableEnumerator { - return &AzurermRouteTableEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermRouteTableEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureRouteTableResourceType -} - -func (e *AzurermRouteTableEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllRouteTables() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_ssh_public_key_enumerator.go b/pkg/remote/azurerm/azurerm_ssh_public_key_enumerator.go deleted file mode 100644 index 604f8cf3..00000000 --- a/pkg/remote/azurerm/azurerm_ssh_public_key_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermSSHPublicKeyEnumerator struct { - repository repository.ComputeRepository - factory resource.ResourceFactory -} - -func NewAzurermSSHPublicKeyEnumerator(repo repository.ComputeRepository, factory resource.ResourceFactory) *AzurermSSHPublicKeyEnumerator { - return &AzurermSSHPublicKeyEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermSSHPublicKeyEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureSSHPublicKeyResourceType -} - -func (e *AzurermSSHPublicKeyEnumerator) Enumerate() ([]*resource.Resource, error) { - keys, err := e.repository.ListAllSSHPublicKeys() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(keys)) - - for _, res := range keys { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_storage_account_enumerator.go b/pkg/remote/azurerm/azurerm_storage_account_enumerator.go deleted file mode 100644 index ba6a1877..00000000 --- a/pkg/remote/azurerm/azurerm_storage_account_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermStorageAccountEnumerator struct { - repository repository.StorageRespository - factory resource.ResourceFactory -} - -func NewAzurermStorageAccountEnumerator(repo repository.StorageRespository, factory resource.ResourceFactory) *AzurermStorageAccountEnumerator { - return &AzurermStorageAccountEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermStorageAccountEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureStorageAccountResourceType -} - -func (e *AzurermStorageAccountEnumerator) Enumerate() ([]*resource.Resource, error) { - accounts, err := e.repository.ListAllStorageAccount() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(accounts)) - - for _, account := range accounts { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *account.ID, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_storage_container_enumerator.go b/pkg/remote/azurerm/azurerm_storage_container_enumerator.go deleted file mode 100644 index 09bb9b35..00000000 --- a/pkg/remote/azurerm/azurerm_storage_container_enumerator.go +++ /dev/null @@ -1,54 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermStorageContainerEnumerator struct { - repository repository.StorageRespository - factory resource.ResourceFactory -} - -func NewAzurermStorageContainerEnumerator(repo repository.StorageRespository, factory resource.ResourceFactory) *AzurermStorageContainerEnumerator { - return &AzurermStorageContainerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermStorageContainerEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureStorageContainerResourceType -} - -func (e *AzurermStorageContainerEnumerator) Enumerate() ([]*resource.Resource, error) { - - accounts, err := e.repository.ListAllStorageAccount() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureStorageAccountResourceType) - } - - results := make([]*resource.Resource, 0) - - for _, account := range accounts { - containers, err := e.repository.ListAllStorageContainer(account) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - for _, container := range containers { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - container, - map[string]interface{}{}, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_subnets_enumerator.go b/pkg/remote/azurerm/azurerm_subnets_enumerator.go deleted file mode 100644 index 64f72523..00000000 --- a/pkg/remote/azurerm/azurerm_subnets_enumerator.go +++ /dev/null @@ -1,51 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermSubnetEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermSubnetEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermSubnetEnumerator { - return &AzurermSubnetEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermSubnetEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureSubnetResourceType -} - -func (e *AzurermSubnetEnumerator) Enumerate() ([]*resource.Resource, error) { - networks, err := e.repository.ListAllVirtualNetworks() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), azurerm.AzureVirtualNetworkResourceType) - } - - results := make([]*resource.Resource, 0) - for _, network := range networks { - resources, err := e.repository.ListAllSubnets(network) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{}, - ), - ) - } - } - - return results, err -} diff --git a/pkg/remote/azurerm/azurerm_virtual_network_enumerator.go b/pkg/remote/azurerm/azurerm_virtual_network_enumerator.go deleted file mode 100644 index 5069539e..00000000 --- a/pkg/remote/azurerm/azurerm_virtual_network_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" -) - -type AzurermVirtualNetworkEnumerator struct { - repository repository.NetworkRepository - factory resource.ResourceFactory -} - -func NewAzurermVirtualNetworkEnumerator(repo repository.NetworkRepository, factory resource.ResourceFactory) *AzurermVirtualNetworkEnumerator { - return &AzurermVirtualNetworkEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *AzurermVirtualNetworkEnumerator) SupportedType() resource.ResourceType { - return azurerm.AzureVirtualNetworkResourceType -} - -func (e *AzurermVirtualNetworkEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.ListAllVirtualNetworks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - *res.ID, - map[string]interface{}{ - "name": *res.Name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/azurerm/init.go b/pkg/remote/azurerm/init.go deleted file mode 100644 index 23523eee..00000000 --- a/pkg/remote/azurerm/init.go +++ /dev/null @@ -1,105 +0,0 @@ -package azurerm - -import ( - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" -) - -func Init( - version string, - alerter *alerter.Alerter, - providerLibrary *terraform.ProviderLibrary, - remoteLibrary *common.RemoteLibrary, - progress output.Progress, - resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory, - configDir string) error { - - provider, err := NewAzureTerraformProvider(version, progress, configDir) - if err != nil { - return err - } - err = provider.CheckCredentialsExist() - if err != nil { - return err - } - err = provider.Init() - if err != nil { - return err - } - - providerConfig := provider.GetConfig() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - return err - } - clientOptions := &arm.ClientOptions{} - - c := cache.New(100) - - storageAccountRepo := repository.NewStorageRepository(cred, clientOptions, providerConfig, c) - networkRepo := repository.NewNetworkRepository(cred, clientOptions, providerConfig, c) - resourcesRepo := repository.NewResourcesRepository(cred, clientOptions, providerConfig, c) - containerRegistryRepo := repository.NewContainerRegistryRepository(cred, clientOptions, providerConfig, c) - postgresqlRepo := repository.NewPostgresqlRepository(cred, clientOptions, providerConfig, c) - privateDNSRepo := repository.NewPrivateDNSRepository(cred, clientOptions, providerConfig, c) - computeRepo := repository.NewComputeRepository(cred, clientOptions, providerConfig, c) - - providerLibrary.AddProvider(terraform.AZURE, provider) - deserializer := resource.NewDeserializer(factory) - - remoteLibrary.AddEnumerator(NewAzurermStorageAccountEnumerator(storageAccountRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermStorageContainerEnumerator(storageAccountRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermVirtualNetworkEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermRouteTableEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermRouteEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermResourceGroupEnumerator(resourcesRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermSubnetEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermContainerRegistryEnumerator(containerRegistryRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermFirewallsEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermPostgresqlServerEnumerator(postgresqlRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermPublicIPEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermPostgresqlDatabaseEnumerator(postgresqlRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermNetworkSecurityGroupEnumerator(networkRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzureNetworkSecurityGroupResourceType, common.NewGenericDetailsFetcher(azurerm.AzureNetworkSecurityGroupResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermLoadBalancerEnumerator(networkRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermLoadBalancerRuleEnumerator(networkRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzureLoadBalancerRuleResourceType, common.NewGenericDetailsFetcher(azurerm.AzureLoadBalancerRuleResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSZoneEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSZoneResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSZoneResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSARecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSARecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSARecordResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSAAAARecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSAAAARecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSAAAARecordResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSMXRecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSMXRecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSMXRecordResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSCNameRecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSCNameRecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSCNameRecordResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSPTRRecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSPTRRecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSPTRRecordResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSSRVRecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSSRVRecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSSRVRecordResourceType, provider, deserializer)) - remoteLibrary.AddEnumerator(NewAzurermPrivateDNSTXTRecordEnumerator(privateDNSRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzurePrivateDNSTXTRecordResourceType, common.NewGenericDetailsFetcher(azurerm.AzurePrivateDNSTXTRecordResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewAzurermImageEnumerator(computeRepo, factory)) - remoteLibrary.AddEnumerator(NewAzurermSSHPublicKeyEnumerator(computeRepo, factory)) - remoteLibrary.AddDetailsFetcher(azurerm.AzureSSHPublicKeyResourceType, common.NewGenericDetailsFetcher(azurerm.AzureSSHPublicKeyResourceType, provider, deserializer)) - - err = resourceSchemaRepository.Init(terraform.AZURE, provider.Version(), provider.Schema()) - if err != nil { - return err - } - azurerm.InitResourcesMetadata(resourceSchemaRepository) - - return nil -} diff --git a/pkg/remote/azurerm/provider.go b/pkg/remote/azurerm/provider.go deleted file mode 100644 index b9a3bc8b..00000000 --- a/pkg/remote/azurerm/provider.go +++ /dev/null @@ -1,95 +0,0 @@ -package azurerm - -import ( - "context" - "errors" - "os" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/terraform" - tf "github.com/snyk/driftctl/pkg/terraform" -) - -type AzureTerraformProvider struct { - *terraform.TerraformProvider - name string - version string -} - -func NewAzureTerraformProvider(version string, progress output.Progress, configDir string) (*AzureTerraformProvider, error) { - if version == "" { - version = "2.71.0" - } - // Just pass your version and name - p := &AzureTerraformProvider{ - version: version, - name: tf.AZURE, - } - // Use TerraformProviderInstaller to retrieve the provider if needed - installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ - Key: p.name, - Version: version, - ConfigDir: configDir, - }) - if err != nil { - return nil, err - } - - tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ - Name: p.name, - GetProviderConfig: func(_ string) interface{} { - c := p.GetConfig() - return map[string]interface{}{ - "subscription_id": c.SubscriptionID, - "tenant_id": c.TenantID, - "client_id": c.ClientID, - "client_secret": c.ClientSecret, - "skip_provider_registration": true, - } - }, - }, progress) - if err != nil { - return nil, err - } - p.TerraformProvider = tfProvider - return p, err -} - -func (p *AzureTerraformProvider) GetConfig() common.AzureProviderConfig { - return common.AzureProviderConfig{ - SubscriptionID: os.Getenv("AZURE_SUBSCRIPTION_ID"), - TenantID: os.Getenv("AZURE_TENANT_ID"), - ClientID: os.Getenv("AZURE_CLIENT_ID"), - ClientSecret: os.Getenv("AZURE_CLIENT_SECRET"), - } -} - -func (p *AzureTerraformProvider) Name() string { - return p.name -} - -func (p *AzureTerraformProvider) Version() string { - return p.version -} - -func (p *AzureTerraformProvider) CheckCredentialsExist() error { - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - return err - } - - _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"https://management.azure.com//.default"}}) - if err != nil { - return errors.New("Could not find any authentication method for Azure.\n" + - "For more information, please check the official Azure documentation: https://docs.microsoft.com/en-us/azure/developer/go/azure-sdk-authorization#use-environment-based-authentication") - } - - if p.GetConfig().SubscriptionID == "" { - return errors.New("Please provide an Azure subscription ID by setting the `AZURE_SUBSCRIPTION_ID` environment variable.") - } - - return nil -} diff --git a/pkg/remote/azurerm/repository/compute.go b/pkg/remote/azurerm/repository/compute.go deleted file mode 100644 index 31aa6544..00000000 --- a/pkg/remote/azurerm/repository/compute.go +++ /dev/null @@ -1,110 +0,0 @@ -package repository - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ComputeRepository interface { - ListAllImages() ([]*armcompute.Image, error) - ListAllSSHPublicKeys() ([]*armcompute.SSHPublicKeyResource, error) -} - -type imagesListPager interface { - pager - PageResponse() armcompute.ImagesListResponse -} - -type imagesClient interface { - List(options *armcompute.ImagesListOptions) imagesListPager -} - -type imagesClientImpl struct { - client *armcompute.ImagesClient -} - -func (c imagesClientImpl) List(options *armcompute.ImagesListOptions) imagesListPager { - return c.client.List(options) -} - -type sshPublicKeyListPager interface { - pager - PageResponse() armcompute.SSHPublicKeysListBySubscriptionResponse -} - -type sshPublicKeyClient interface { - ListBySubscription(options *armcompute.SSHPublicKeysListBySubscriptionOptions) sshPublicKeyListPager -} - -type sshPublicKeyClientImpl struct { - client *armcompute.SSHPublicKeysClient -} - -func (c sshPublicKeyClientImpl) ListBySubscription(options *armcompute.SSHPublicKeysListBySubscriptionOptions) sshPublicKeyListPager { - return c.client.ListBySubscription(options) -} - -type computeRepository struct { - imagesClient imagesClient - sshPublicKeyClient sshPublicKeyClient - cache cache.Cache -} - -func NewComputeRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *computeRepository { - return &computeRepository{ - &imagesClientImpl{armcompute.NewImagesClient(config.SubscriptionID, cred, options)}, - &sshPublicKeyClientImpl{armcompute.NewSSHPublicKeysClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *computeRepository) ListAllImages() ([]*armcompute.Image, error) { - cacheKey := "computeListAllImages" - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armcompute.Image), nil - } - - pager := s.imagesClient.List(nil) - results := make([]*armcompute.Image, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - return results, nil -} - -func (s *computeRepository) ListAllSSHPublicKeys() ([]*armcompute.SSHPublicKeyResource, error) { - cacheKey := "computeListAllSSHPublicKeys" - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armcompute.SSHPublicKeyResource), nil - } - - pager := s.sshPublicKeyClient.ListBySubscription(nil) - results := make([]*armcompute.SSHPublicKeyResource, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - return results, nil -} diff --git a/pkg/remote/azurerm/repository/compute_test.go b/pkg/remote/azurerm/repository/compute_test.go deleted file mode 100644 index 8d416bcd..00000000 --- a/pkg/remote/azurerm/repository/compute_test.go +++ /dev/null @@ -1,275 +0,0 @@ -package repository - -import ( - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_Compute_ListAllImages(t *testing.T) { - expectedResults := []*armcompute.Image{ - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/tfvmex-resources/providers/Microsoft.Compute/images/image1"), - Name: to.StringPtr("image1"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/tfvmex-resources/providers/Microsoft.Compute/images/image2"), - Name: to.StringPtr("image2"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/tfvmex-resources/providers/Microsoft.Compute/images/image3"), - Name: to.StringPtr("image3"), - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockImagesListPager, *cache.MockCache) - expected []*armcompute.Image - wantErr string - }{ - { - name: "should return images", - mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ - ImagesListResult: armcompute.ImagesListResult{ - ImageListResult: armcompute.ImageListResult{ - Value: expectedResults[:2], - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ - ImagesListResult: armcompute.ImagesListResult{ - ImageListResult: armcompute.ImageListResult{ - Value: expectedResults[2:], - }, - }, - }).Times(1) - - mockCache.On("Get", "computeListAllImages").Return(nil).Times(1) - mockCache.On("Put", "computeListAllImages", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return images", - mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { - mockCache.On("Get", "computeListAllImages").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ - ImagesListResult: armcompute.ImagesListResult{ - ImageListResult: armcompute.ImageListResult{ - Value: []*armcompute.Image{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "computeListAllImages").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - { - name: "should return remote error after fetching all pages", - mocks: func(mockPager *mockImagesListPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armcompute.ImagesListResponse{ - ImagesListResult: armcompute.ImagesListResult{ - ImageListResult: armcompute.ImageListResult{ - Value: []*armcompute.Image{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(nil).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "computeListAllImages").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakeClient := &mockImagesClient{} - mockPager := &mockImagesListPager{} - mockCache := &cache.MockCache{} - - fakeClient.On("List", mock.Anything).Maybe().Return(mockPager) - - tt.mocks(mockPager, mockCache) - - s := &computeRepository{ - imagesClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllImages() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockPager.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} - -func Test_Compute_ListAllSSHPublicKeys(t *testing.T) { - expectedResults := []*armcompute.SSHPublicKeyResource{ - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/key1"), - Name: to.StringPtr("key1"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/key2"), - Name: to.StringPtr("key2"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/key3"), - Name: to.StringPtr("key3"), - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockSshPublicKeyListPager, *cache.MockCache) - expected []*armcompute.SSHPublicKeyResource - wantErr string - }{ - { - name: "should return SSH public keys", - mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ - SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ - SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ - Value: expectedResults[:2], - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ - SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ - SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ - Value: expectedResults[2:], - }, - }, - }).Times(1) - - mockCache.On("Get", "computeListAllSSHPublicKeys").Return(nil).Times(1) - mockCache.On("Put", "computeListAllSSHPublicKeys", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return SSH public keys", - mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { - mockCache.On("Get", "computeListAllSSHPublicKeys").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ - SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ - SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ - Value: []*armcompute.SSHPublicKeyResource{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "computeListAllSSHPublicKeys").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - { - name: "should return remote error after fetching all pages", - mocks: func(mockPager *mockSshPublicKeyListPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armcompute.SSHPublicKeysListBySubscriptionResponse{ - SSHPublicKeysListBySubscriptionResult: armcompute.SSHPublicKeysListBySubscriptionResult{ - SSHPublicKeysGroupListResult: armcompute.SSHPublicKeysGroupListResult{ - Value: []*armcompute.SSHPublicKeyResource{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(nil).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "computeListAllSSHPublicKeys").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakeClient := &mockSshPublicKeyClient{} - mockPager := &mockSshPublicKeyListPager{} - mockCache := &cache.MockCache{} - - fakeClient.On("ListBySubscription", mock.Anything).Maybe().Return(mockPager) - - tt.mocks(mockPager, mockCache) - - s := &computeRepository{ - sshPublicKeyClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllSSHPublicKeys() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockPager.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/pkg/remote/azurerm/repository/containerregistry.go b/pkg/remote/azurerm/repository/containerregistry.go deleted file mode 100644 index fd1b35b5..00000000 --- a/pkg/remote/azurerm/repository/containerregistry.go +++ /dev/null @@ -1,69 +0,0 @@ -package repository - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ContainerRegistryRepository interface { - ListAllContainerRegistries() ([]*armcontainerregistry.Registry, error) -} - -type registryClient interface { - List(options *armcontainerregistry.RegistriesListOptions) registryListAllPager -} - -type registryListAllPager interface { - pager - PageResponse() armcontainerregistry.RegistriesListResponse -} - -type registryClientImpl struct { - client *armcontainerregistry.RegistriesClient -} - -func (c registryClientImpl) List(options *armcontainerregistry.RegistriesListOptions) registryListAllPager { - return c.client.List(options) -} - -type containerRegistryRepository struct { - registryClient registryClient - cache cache.Cache -} - -func NewContainerRegistryRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *containerRegistryRepository { - return &containerRegistryRepository{ - ®istryClientImpl{client: armcontainerregistry.NewRegistriesClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *containerRegistryRepository) ListAllContainerRegistries() ([]*armcontainerregistry.Registry, error) { - - if v := s.cache.Get("ListAllContainerRegistries"); v != nil { - return v.([]*armcontainerregistry.Registry), nil - } - - pager := s.registryClient.List(nil) - results := make([]*armcontainerregistry.Registry, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put("ListAllContainerRegistries", results) - - return results, nil -} diff --git a/pkg/remote/azurerm/repository/containerregistry_test.go b/pkg/remote/azurerm/repository/containerregistry_test.go deleted file mode 100644 index 5df8159d..00000000 --- a/pkg/remote/azurerm/repository/containerregistry_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package repository - -import ( - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_Resources_ListAllContainerRegistries(t *testing.T) { - expectedResults := []*armcontainerregistry.Registry{ - { - Resource: armcontainerregistry.Resource{ - ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/my-group/providers/Microsoft.ContainerRegistry/registries/containerRegistry1"), - Name: to.StringPtr("containerRegistry1"), - }, - }, - { - Resource: armcontainerregistry.Resource{ - ID: to.StringPtr("/subscriptions/2c361f34-30fb-47ae-a227-83a5d3a26c66/resourceGroups/my-group/providers/Microsoft.ContainerRegistry/registries/containerRegistry1"), - Name: to.StringPtr("containerRegistry2"), - }, - }, - { - Resource: armcontainerregistry.Resource{ - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/-/resource-3"), - Name: to.StringPtr("resource-3"), - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockRegistryListAllPager, *cache.MockCache) - expected []*armcontainerregistry.Registry - wantErr string - }{ - { - name: "should return container registries", - mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ - RegistriesListResult: armcontainerregistry.RegistriesListResult{ - RegistryListResult: armcontainerregistry.RegistryListResult{ - Value: expectedResults[:2], - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ - RegistriesListResult: armcontainerregistry.RegistriesListResult{ - RegistryListResult: armcontainerregistry.RegistryListResult{ - Value: expectedResults[2:], - }, - }, - }).Times(1) - - mockCache.On("Get", "ListAllContainerRegistries").Return(nil).Times(1) - mockCache.On("Put", "ListAllContainerRegistries", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return container registries", - mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { - mockCache.On("Get", "ListAllContainerRegistries").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ - RegistriesListResult: armcontainerregistry.RegistriesListResult{ - RegistryListResult: armcontainerregistry.RegistryListResult{ - Value: []*armcontainerregistry.Registry{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "ListAllContainerRegistries").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - { - name: "should return remote error after fetching all pages", - mocks: func(mockPager *mockRegistryListAllPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armcontainerregistry.RegistriesListResponse{ - RegistriesListResult: armcontainerregistry.RegistriesListResult{ - RegistryListResult: armcontainerregistry.RegistryListResult{ - Value: []*armcontainerregistry.Registry{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(nil).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "ListAllContainerRegistries").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakeClient := &mockRegistryClient{} - mockPager := &mockRegistryListAllPager{} - mockCache := &cache.MockCache{} - - fakeClient.On("List", mock.Anything).Maybe().Return(mockPager) - - tt.mocks(mockPager, mockCache) - - s := &containerRegistryRepository{ - registryClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllContainerRegistries() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockPager.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/pkg/remote/azurerm/repository/network.go b/pkg/remote/azurerm/repository/network.go deleted file mode 100644 index da9c131e..00000000 --- a/pkg/remote/azurerm/repository/network.go +++ /dev/null @@ -1,405 +0,0 @@ -package repository - -import ( - "context" - "fmt" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type NetworkRepository interface { - ListAllVirtualNetworks() ([]*armnetwork.VirtualNetwork, error) - ListAllRouteTables() ([]*armnetwork.RouteTable, error) - ListAllSubnets(virtualNetwork *armnetwork.VirtualNetwork) ([]*armnetwork.Subnet, error) - ListAllFirewalls() ([]*armnetwork.AzureFirewall, error) - ListAllPublicIPAddresses() ([]*armnetwork.PublicIPAddress, error) - ListAllSecurityGroups() ([]*armnetwork.NetworkSecurityGroup, error) - ListAllLoadBalancers() ([]*armnetwork.LoadBalancer, error) - ListLoadBalancerRules(*armnetwork.LoadBalancer) ([]*armnetwork.LoadBalancingRule, error) -} - -type publicIPAddressesClient interface { - ListAll(options *armnetwork.PublicIPAddressesListAllOptions) publicIPAddressesListAllPager -} - -type publicIPAddressesListAllPager interface { - pager - PageResponse() armnetwork.PublicIPAddressesListAllResponse -} - -type publicIPAddressesClientImpl struct { - client *armnetwork.PublicIPAddressesClient -} - -func (p publicIPAddressesClientImpl) ListAll(options *armnetwork.PublicIPAddressesListAllOptions) publicIPAddressesListAllPager { - return p.client.ListAll(options) -} - -type firewallsListAllPager interface { - pager - PageResponse() armnetwork.AzureFirewallsListAllResponse -} - -type firewallsClient interface { - ListAll(options *armnetwork.AzureFirewallsListAllOptions) firewallsListAllPager -} - -type firewallsClientImpl struct { - client *armnetwork.AzureFirewallsClient -} - -func (s firewallsClientImpl) ListAll(options *armnetwork.AzureFirewallsListAllOptions) firewallsListAllPager { - return s.client.ListAll(options) -} - -type subnetsListPager interface { - pager - PageResponse() armnetwork.SubnetsListResponse -} - -type subnetsClient interface { - List(resourceGroupName, virtualNetworkName string, options *armnetwork.SubnetsListOptions) subnetsListPager -} - -type subnetsClientImpl struct { - client *armnetwork.SubnetsClient -} - -func (s subnetsClientImpl) List(resourceGroupName, virtualNetworkName string, options *armnetwork.SubnetsListOptions) subnetsListPager { - return s.client.List(resourceGroupName, virtualNetworkName, options) -} - -type virtualNetworksClient interface { - ListAll(options *armnetwork.VirtualNetworksListAllOptions) virtualNetworksListAllPager -} - -type virtualNetworksListAllPager interface { - pager - PageResponse() armnetwork.VirtualNetworksListAllResponse -} - -type virtualNetworksClientImpl struct { - client *armnetwork.VirtualNetworksClient -} - -func (c virtualNetworksClientImpl) ListAll(options *armnetwork.VirtualNetworksListAllOptions) virtualNetworksListAllPager { - return c.client.ListAll(options) -} - -type routeTablesClient interface { - ListAll(options *armnetwork.RouteTablesListAllOptions) routeTablesListAllPager -} - -type routeTablesListAllPager interface { - pager - PageResponse() armnetwork.RouteTablesListAllResponse -} - -type routeTablesClientImpl struct { - client *armnetwork.RouteTablesClient -} - -func (c routeTablesClientImpl) ListAll(options *armnetwork.RouteTablesListAllOptions) routeTablesListAllPager { - return c.client.ListAll(options) -} - -type networkSecurityGroupsListAllPager interface { - pager - PageResponse() armnetwork.NetworkSecurityGroupsListAllResponse -} - -type networkSecurityGroupsClient interface { - ListAll(options *armnetwork.NetworkSecurityGroupsListAllOptions) networkSecurityGroupsListAllPager -} - -type networkSecurityGroupsClientImpl struct { - client *armnetwork.NetworkSecurityGroupsClient -} - -func (s networkSecurityGroupsClientImpl) ListAll(options *armnetwork.NetworkSecurityGroupsListAllOptions) networkSecurityGroupsListAllPager { - return s.client.ListAll(options) -} - -type loadBalancersListAllPager interface { - pager - PageResponse() armnetwork.LoadBalancersListAllResponse -} - -type loadBalancersClient interface { - ListAll(options *armnetwork.LoadBalancersListAllOptions) loadBalancersListAllPager -} - -type loadBalancersClientImpl struct { - client *armnetwork.LoadBalancersClient -} - -func (s loadBalancersClientImpl) ListAll(options *armnetwork.LoadBalancersListAllOptions) loadBalancersListAllPager { - return s.client.ListAll(options) -} - -type loadBalancerRulesListAllPager interface { - pager - PageResponse() armnetwork.LoadBalancerLoadBalancingRulesListResponse -} - -type loadBalancerRulesClient interface { - List(string, string, *armnetwork.LoadBalancerLoadBalancingRulesListOptions) loadBalancerRulesListAllPager -} - -type loadBalancerRulesClientImpl struct { - client *armnetwork.LoadBalancerLoadBalancingRulesClient -} - -func (s loadBalancerRulesClientImpl) List(resourceGroupName string, loadBalancerName string, options *armnetwork.LoadBalancerLoadBalancingRulesListOptions) loadBalancerRulesListAllPager { - return s.client.List(resourceGroupName, loadBalancerName, options) -} - -type networkRepository struct { - virtualNetworksClient virtualNetworksClient - routeTableClient routeTablesClient - subnetsClient subnetsClient - firewallsClient firewallsClient - publicIPAddressesClient publicIPAddressesClient - networkSecurityGroupsClient networkSecurityGroupsClient - loadBalancersClient loadBalancersClient - loadBalancerRulesClient loadBalancerRulesClient - cache cache.Cache -} - -func NewNetworkRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *networkRepository { - return &networkRepository{ - &virtualNetworksClientImpl{client: armnetwork.NewVirtualNetworksClient(config.SubscriptionID, cred, options)}, - &routeTablesClientImpl{client: armnetwork.NewRouteTablesClient(config.SubscriptionID, cred, options)}, - &subnetsClientImpl{client: armnetwork.NewSubnetsClient(config.SubscriptionID, cred, options)}, - &firewallsClientImpl{client: armnetwork.NewAzureFirewallsClient(config.SubscriptionID, cred, options)}, - &publicIPAddressesClientImpl{client: armnetwork.NewPublicIPAddressesClient(config.SubscriptionID, cred, options)}, - &networkSecurityGroupsClientImpl{client: armnetwork.NewNetworkSecurityGroupsClient(config.SubscriptionID, cred, options)}, - &loadBalancersClientImpl{client: armnetwork.NewLoadBalancersClient(config.SubscriptionID, cred, options)}, - &loadBalancerRulesClientImpl{armnetwork.NewLoadBalancerLoadBalancingRulesClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *networkRepository) ListAllVirtualNetworks() ([]*armnetwork.VirtualNetwork, error) { - - cacheKey := "ListAllVirtualNetworks" - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.([]*armnetwork.VirtualNetwork), nil - } - - pager := s.virtualNetworksClient.ListAll(nil) - results := make([]*armnetwork.VirtualNetwork, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.VirtualNetworksListAllResult.VirtualNetworkListResult.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *networkRepository) ListAllRouteTables() ([]*armnetwork.RouteTable, error) { - cacheKey := "ListAllRouteTables" - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.([]*armnetwork.RouteTable), nil - } - - pager := s.routeTableClient.ListAll(nil) - results := make([]*armnetwork.RouteTable, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.RouteTablesListAllResult.RouteTableListResult.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *networkRepository) ListAllSubnets(virtualNetwork *armnetwork.VirtualNetwork) ([]*armnetwork.Subnet, error) { - - cacheKey := fmt.Sprintf("ListAllSubnets_%s", *virtualNetwork.ID) - - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armnetwork.Subnet), nil - } - - res, err := azure.ParseResourceID(*virtualNetwork.ID) - if err != nil { - return nil, err - } - - pager := s.subnetsClient.List(res.ResourceGroup, *virtualNetwork.Name, nil) - results := make([]*armnetwork.Subnet, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.SubnetsListResult.SubnetListResult.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *networkRepository) ListAllFirewalls() ([]*armnetwork.AzureFirewall, error) { - - cacheKey := "ListAllFirewalls" - - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armnetwork.AzureFirewall), nil - } - - pager := s.firewallsClient.ListAll(nil) - results := make([]*armnetwork.AzureFirewall, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.AzureFirewallsListAllResult.AzureFirewallListResult.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *networkRepository) ListAllPublicIPAddresses() ([]*armnetwork.PublicIPAddress, error) { - cacheKey := "ListAllPublicIPAddresses" - - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armnetwork.PublicIPAddress), nil - } - - pager := s.publicIPAddressesClient.ListAll(nil) - results := make([]*armnetwork.PublicIPAddress, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.PublicIPAddressesListAllResult.PublicIPAddressListResult.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *networkRepository) ListAllSecurityGroups() ([]*armnetwork.NetworkSecurityGroup, error) { - cacheKey := "networkListAllSecurityGroups" - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armnetwork.NetworkSecurityGroup), nil - } - - pager := s.networkSecurityGroupsClient.ListAll(nil) - results := make([]*armnetwork.NetworkSecurityGroup, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *networkRepository) ListAllLoadBalancers() ([]*armnetwork.LoadBalancer, error) { - cacheKey := "networkListAllLoadBalancers" - defer s.cache.Unlock(cacheKey) - if v := s.cache.GetAndLock(cacheKey); v != nil { - return v.([]*armnetwork.LoadBalancer), nil - } - - pager := s.loadBalancersClient.ListAll(nil) - results := make([]*armnetwork.LoadBalancer, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - return results, nil -} - -func (s *networkRepository) ListLoadBalancerRules(loadBalancer *armnetwork.LoadBalancer) ([]*armnetwork.LoadBalancingRule, error) { - cacheKey := fmt.Sprintf("networkListLoadBalancerRules_%s", *loadBalancer.ID) - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armnetwork.LoadBalancingRule), nil - } - - loadBalancerResource, err := azure.ParseResourceID(*loadBalancer.ID) - if err != nil { - return nil, err - } - - pager := s.loadBalancerRulesClient.List(loadBalancerResource.ResourceGroup, loadBalancerResource.ResourceName, &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}) - results := make([]*armnetwork.LoadBalancingRule, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - return results, nil -} diff --git a/pkg/remote/azurerm/repository/network_test.go b/pkg/remote/azurerm/repository/network_test.go deleted file mode 100644 index fa42fdfa..00000000 --- a/pkg/remote/azurerm/repository/network_test.go +++ /dev/null @@ -1,1172 +0,0 @@ -package repository - -import ( - "context" - "fmt" - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_ListAllVirtualNetwork_MultiplesResults(t *testing.T) { - - expected := []*armnetwork.VirtualNetwork{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network2"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network3"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network4"), - }, - }, - } - - fakeClient := &mockVirtualNetworkClient{} - - mockPager := &mockVirtualNetworksListAllPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{ - VirtualNetworksListAllResult: armnetwork.VirtualNetworksListAllResult{ - VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{ - Value: []*armnetwork.VirtualNetwork{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network2"), - }, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{ - VirtualNetworksListAllResult: armnetwork.VirtualNetworksListAllResult{ - VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{ - Value: []*armnetwork.VirtualNetwork{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network3"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network4"), - }, - }, - }, - }, - }, - }).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "ListAllVirtualNetworks").Return(nil).Times(1) - c.On("Unlock", "ListAllVirtualNetworks").Times(1) - c.On("Put", "ListAllVirtualNetworks", expected).Return(true).Times(1) - s := &networkRepository{ - virtualNetworksClient: fakeClient, - cache: c, - } - got, err := s.ListAllVirtualNetworks() - if err != nil { - t.Errorf("ListAllVirtualNetworks() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllVirtualNetworks() got = %v, want %v", got, expected) - } -} - -func Test_ListAllVirtualNetwork_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armnetwork.VirtualNetwork{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network3"), - }, - }, - } - - fakeClient := &mockVirtualNetworkClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "ListAllVirtualNetworks").Return(expected).Times(1) - c.On("Unlock", "ListAllVirtualNetworks").Times(1) - s := &networkRepository{ - virtualNetworksClient: fakeClient, - cache: c, - } - got, err := s.ListAllVirtualNetworks() - if err != nil { - t.Errorf("ListAllVirtualNetworks() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllVirtualNetworks() got = %v, want %v", got, expected) - } -} - -func Test_ListAllVirtualNetwork_Error_OnPageResponse(t *testing.T) { - - fakeClient := &mockVirtualNetworkClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockVirtualNetworksListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armnetwork.VirtualNetworksListAllResponse{}).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - virtualNetworksClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllVirtualNetworks() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllVirtualNetwork_Error(t *testing.T) { - - fakeClient := &mockVirtualNetworkClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockVirtualNetworksListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - virtualNetworksClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllVirtualNetworks() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllRouteTables_MultiplesResults(t *testing.T) { - - expected := []*armnetwork.RouteTable{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("table1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("table2"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("table3"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("table4"), - }, - }, - } - - fakeClient := &mockRouteTablesClient{} - - mockPager := &mockRouteTablesListAllPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{ - RouteTablesListAllResult: armnetwork.RouteTablesListAllResult{ - RouteTableListResult: armnetwork.RouteTableListResult{ - Value: expected[:2], - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{ - RouteTablesListAllResult: armnetwork.RouteTablesListAllResult{ - RouteTableListResult: armnetwork.RouteTableListResult{ - Value: expected[2:], - }, - }, - }).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "ListAllRouteTables").Return(nil).Times(1) - c.On("Unlock", "ListAllRouteTables").Times(1) - c.On("Put", "ListAllRouteTables", expected).Return(true).Times(1) - s := &networkRepository{ - routeTableClient: fakeClient, - cache: c, - } - got, err := s.ListAllRouteTables() - if err != nil { - t.Errorf("ListAllRouteTables() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllRouteTables() got = %v, want %v", got, expected) - } -} - -func Test_ListAllRouteTables_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armnetwork.RouteTable{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("table1"), - }, - }, - } - - fakeClient := &mockRouteTablesClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "ListAllRouteTables").Return(expected).Times(1) - c.On("Unlock", "ListAllRouteTables").Times(1) - s := &networkRepository{ - routeTableClient: fakeClient, - cache: c, - } - got, err := s.ListAllRouteTables() - if err != nil { - t.Errorf("ListAllRouteTables() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllRouteTables() got = %v, want %v", got, expected) - } -} - -func Test_ListAllRouteTables_Error_OnPageResponse(t *testing.T) { - - fakeClient := &mockRouteTablesClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockRouteTablesListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armnetwork.RouteTablesListAllResponse{}).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - routeTableClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllRouteTables() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllRouteTables_Error(t *testing.T) { - - fakeClient := &mockRouteTablesClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockRouteTablesListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - routeTableClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllRouteTables() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllSubnets_MultiplesResults(t *testing.T) { - - network := &armnetwork.VirtualNetwork{ - Resource: armnetwork.Resource{ - Name: to.StringPtr("network1"), - ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"), - }, - } - - expected := []*armnetwork.Subnet{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet1"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet2"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet3"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet4"), - }, - }, - } - - fakeClient := &mockSubnetsClient{} - - mockPager := &mockSubnetsListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{ - SubnetsListResult: armnetwork.SubnetsListResult{ - SubnetListResult: armnetwork.SubnetListResult{ - Value: []*armnetwork.Subnet{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet1"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet2"), - }, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{ - SubnetsListResult: armnetwork.SubnetsListResult{ - SubnetListResult: armnetwork.SubnetListResult{ - Value: []*armnetwork.Subnet{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet3"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet4"), - }, - }, - }, - }, - }, - }).Times(1) - - fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - cacheKey := fmt.Sprintf("ListAllSubnets_%s", *network.ID) - c.On("Get", cacheKey).Return(nil).Times(1) - c.On("Put", cacheKey, expected).Return(true).Times(1) - s := &networkRepository{ - subnetsClient: fakeClient, - cache: c, - } - got, err := s.ListAllSubnets(network) - if err != nil { - t.Errorf("ListAllSubnets() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllSubnets() got = %v, want %v", got, expected) - } -} - -func Test_ListAllSubnets_MultiplesResults_WithCache(t *testing.T) { - - network := &armnetwork.VirtualNetwork{ - Resource: armnetwork.Resource{ - ID: to.StringPtr("networkID"), - }, - } - - expected := []*armnetwork.Subnet{ - { - Name: to.StringPtr("network1"), - }, - } - fakeClient := &mockSubnetsClient{} - - c := &cache.MockCache{} - c.On("Get", "ListAllSubnets_networkID").Return(expected).Times(1) - s := &networkRepository{ - subnetsClient: fakeClient, - cache: c, - } - got, err := s.ListAllSubnets(network) - if err != nil { - t.Errorf("ListAllSubnets() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllSubnets() got = %v, want %v", got, expected) - } -} - -func Test_ListAllSubnets_Error_OnPageResponse(t *testing.T) { - - network := &armnetwork.VirtualNetwork{ - Resource: armnetwork.Resource{ - Name: to.StringPtr("network1"), - ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"), - }, - } - - fakeClient := &mockSubnetsClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockSubnetsListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armnetwork.SubnetsListResponse{}).Times(1) - - fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager) - - s := &networkRepository{ - subnetsClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllSubnets(network) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllSubnets_Error(t *testing.T) { - - network := &armnetwork.VirtualNetwork{ - Resource: armnetwork.Resource{ - Name: to.StringPtr("network1"), - ID: to.StringPtr("/subscriptions/7bfb2c5c-0000-0000-0000-fffa356eb406/resourceGroups/test-dev/providers/Microsoft.Network/virtualNetworks/network1"), - }, - } - - fakeClient := &mockSubnetsClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockSubnetsListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - - fakeClient.On("List", "test-dev", "network1", mock.Anything).Return(mockPager) - - s := &networkRepository{ - subnetsClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllSubnets(network) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllSubnets_ErrorOnInvalidNetworkID(t *testing.T) { - - network := &armnetwork.VirtualNetwork{ - Resource: armnetwork.Resource{ - Name: to.StringPtr("network1"), - ID: to.StringPtr("foobar"), - }, - } - - fakeClient := &mockSubnetsClient{} - - expectedErr := errors.New("parsing failed for foobar. Invalid resource Id format") - - s := &networkRepository{ - subnetsClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllSubnets(network) - - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr.Error(), err.Error()) - assert.Nil(t, got) -} - -func Test_ListAllFirewalls_MultiplesResults(t *testing.T) { - - expected := []*armnetwork.AzureFirewall{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall2"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall3"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall4"), - }, - }, - } - - fakeClient := &mockFirewallsClient{} - - mockPager := &mockFirewallsListAllPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{ - AzureFirewallsListAllResult: armnetwork.AzureFirewallsListAllResult{ - AzureFirewallListResult: armnetwork.AzureFirewallListResult{ - Value: expected[:2], - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{ - AzureFirewallsListAllResult: armnetwork.AzureFirewallsListAllResult{ - AzureFirewallListResult: armnetwork.AzureFirewallListResult{ - Value: expected[2:], - }, - }, - }).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - c.On("Get", "ListAllFirewalls").Return(nil).Times(1) - c.On("Put", "ListAllFirewalls", expected).Return(true).Times(1) - s := &networkRepository{ - firewallsClient: fakeClient, - cache: c, - } - got, err := s.ListAllFirewalls() - if err != nil { - t.Errorf("ListAllFirewalls() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllFirewalls() got = %v, want %v", got, expected) - } -} - -func Test_ListAllFirewalls_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armnetwork.AzureFirewall{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall1"), - }, - }, - } - - fakeClient := &mockFirewallsClient{} - - c := &cache.MockCache{} - c.On("Get", "ListAllFirewalls").Return(expected).Times(1) - s := &networkRepository{ - firewallsClient: fakeClient, - cache: c, - } - got, err := s.ListAllFirewalls() - if err != nil { - t.Errorf("ListAllFirewalls() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllFirewalls() got = %v, want %v", got, expected) - } -} - -func Test_ListAllFirewalls_Error_OnPageResponse(t *testing.T) { - - fakeClient := &mockFirewallsClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockFirewallsListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armnetwork.AzureFirewallsListAllResponse{}).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - firewallsClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllFirewalls() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllFirewalls_Error(t *testing.T) { - - fakeClient := &mockFirewallsClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockFirewallsListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - firewallsClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllFirewalls() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllPublicIPAddresses_MultiplesResults(t *testing.T) { - - expected := []*armnetwork.PublicIPAddress{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip2"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip3"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip4"), - }, - }, - } - - fakeClient := &mockPublicIPAddressesClient{} - - mockPager := &mockPublicIPAddressesListAllPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{ - PublicIPAddressesListAllResult: armnetwork.PublicIPAddressesListAllResult{ - PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{ - Value: expected[:2], - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{ - PublicIPAddressesListAllResult: armnetwork.PublicIPAddressesListAllResult{ - PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{ - Value: expected[2:], - }, - }, - }).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - c.On("Get", "ListAllPublicIPAddresses").Return(nil).Times(1) - c.On("Put", "ListAllPublicIPAddresses", expected).Return(true).Times(1) - s := &networkRepository{ - publicIPAddressesClient: fakeClient, - cache: c, - } - got, err := s.ListAllPublicIPAddresses() - if err != nil { - t.Errorf("ListAllPublicIPAddresses() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllPublicIPAddresses() got = %v, want %v", got, expected) - } -} - -func Test_ListAllPublicIPAddresses_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armnetwork.PublicIPAddress{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip1"), - }, - }, - } - - fakeClient := &mockPublicIPAddressesClient{} - - c := &cache.MockCache{} - c.On("Get", "ListAllPublicIPAddresses").Return(expected).Times(1) - s := &networkRepository{ - publicIPAddressesClient: fakeClient, - cache: c, - } - got, err := s.ListAllPublicIPAddresses() - if err != nil { - t.Errorf("ListAllPublicIPAddresses() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllPublicIPAddresses() got = %v, want %v", got, expected) - } -} - -func Test_ListAllPublicIPAddresses_Error_OnPageResponse(t *testing.T) { - - fakeClient := &mockPublicIPAddressesClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPublicIPAddressesListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armnetwork.PublicIPAddressesListAllResponse{}).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - publicIPAddressesClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllPublicIPAddresses() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllPublicIPAddresses_Error(t *testing.T) { - - fakeClient := &mockPublicIPAddressesClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPublicIPAddressesListAllPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - - fakeClient.On("ListAll", mock.Anything).Return(mockPager) - - s := &networkRepository{ - publicIPAddressesClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllPublicIPAddresses() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_Network_ListAllSecurityGroups(t *testing.T) { - expectedResults := []*armnetwork.NetworkSecurityGroup{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("sgroup-1"), - Name: to.StringPtr("sgroup-1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("sgroup-2"), - Name: to.StringPtr("sgroup-2"), - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockNetworkSecurityGroupsListAllPager, *cache.MockCache) - expected []*armnetwork.NetworkSecurityGroup - wantErr string - }{ - { - name: "should return security groups", - mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache.MockCache) { - pager.On("NextPage", context.Background()).Return(true).Times(1) - pager.On("NextPage", context.Background()).Return(false).Times(1) - pager.On("PageResponse").Return(armnetwork.NetworkSecurityGroupsListAllResponse{ - NetworkSecurityGroupsListAllResult: armnetwork.NetworkSecurityGroupsListAllResult{ - NetworkSecurityGroupListResult: armnetwork.NetworkSecurityGroupListResult{ - Value: expectedResults, - }, - }, - }).Times(1) - pager.On("Err").Return(nil).Times(2) - - mockCache.On("Get", "networkListAllSecurityGroups").Return(nil).Times(1) - mockCache.On("Put", "networkListAllSecurityGroups", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return security groups", - mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache.MockCache) { - mockCache.On("Get", "networkListAllSecurityGroups").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(pager *mockNetworkSecurityGroupsListAllPager, mockCache *cache.MockCache) { - pager.On("NextPage", context.Background()).Return(true).Times(1) - pager.On("NextPage", context.Background()).Return(false).Times(1) - pager.On("PageResponse").Return(armnetwork.NetworkSecurityGroupsListAllResponse{}).Times(1) - pager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "networkListAllSecurityGroups").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakePager := &mockNetworkSecurityGroupsListAllPager{} - fakeClient := &mockNetworkSecurityGroupsClient{} - mockCache := &cache.MockCache{} - - fakeClient.On("ListAll", (*armnetwork.NetworkSecurityGroupsListAllOptions)(nil)).Return(fakePager).Maybe() - - tt.mocks(fakePager, mockCache) - - s := &networkRepository{ - networkSecurityGroupsClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllSecurityGroups() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllSecurityGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} - -func Test_Network_ListAllLoadBalancers(t *testing.T) { - expectedResults := []*armnetwork.LoadBalancer{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("lb-1"), - Name: to.StringPtr("lb-1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("lb-2"), - Name: to.StringPtr("lb-2"), - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockLoadBalancersListAllPager, *cache.MockCache) - expected []*armnetwork.LoadBalancer - wantErr string - }{ - { - name: "should return load balancers", - mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache.MockCache) { - pager.On("NextPage", context.Background()).Return(true).Times(1) - pager.On("NextPage", context.Background()).Return(false).Times(1) - pager.On("PageResponse").Return(armnetwork.LoadBalancersListAllResponse{ - LoadBalancersListAllResult: armnetwork.LoadBalancersListAllResult{ - LoadBalancerListResult: armnetwork.LoadBalancerListResult{ - Value: expectedResults, - }, - }, - }).Times(1) - pager.On("Err").Return(nil).Times(2) - - mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(nil).Times(1) - mockCache.On("Put", "networkListAllLoadBalancers", expectedResults).Return(false).Times(1) - mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return load balancers", - mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache.MockCache) { - mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(expectedResults).Times(1) - mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(pager *mockLoadBalancersListAllPager, mockCache *cache.MockCache) { - pager.On("NextPage", context.Background()).Return(true).Times(1) - pager.On("NextPage", context.Background()).Return(false).Times(1) - pager.On("PageResponse").Return(armnetwork.LoadBalancersListAllResponse{}).Times(1) - pager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("GetAndLock", "networkListAllLoadBalancers").Return(nil).Times(1) - mockCache.On("Unlock", "networkListAllLoadBalancers").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakePager := &mockLoadBalancersListAllPager{} - fakeClient := &mockLoadBalancersClient{} - mockCache := &cache.MockCache{} - - fakeClient.On("ListAll", (*armnetwork.LoadBalancersListAllOptions)(nil)).Return(fakePager).Maybe() - - tt.mocks(fakePager, mockCache) - - s := &networkRepository{ - loadBalancersClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllLoadBalancers() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllLoadBalancers() got = %v, want %v", got, tt.expected) - } - }) - } -} - -func Test_Network_ListLoadBalancerRules(t *testing.T) { - expectedResults := []*armnetwork.LoadBalancingRule{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("lbrule-1"), - }, - Name: to.StringPtr("lbrule-1"), - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("lbrule-1"), - }, - Name: to.StringPtr("lbrule-1"), - }, - } - - testcases := []struct { - name string - loadBalancer *armnetwork.LoadBalancer - mocks func(*mockLoadBalancerRulesClient, *mockLoadBalancerRulesListAllPager, *cache.MockCache) - expected []*armnetwork.LoadBalancingRule - wantErr string - }{ - { - name: "should return load balancer rules", - loadBalancer: &armnetwork.LoadBalancer{ - Resource: armnetwork.Resource{ID: to.StringPtr("/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress")}, - }, - mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache.MockCache) { - client.On("List", "driftctl", "PublicIPAddress", &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}).Return(pager) - - pager.On("NextPage", context.Background()).Return(true).Times(1) - pager.On("NextPage", context.Background()).Return(false).Times(1) - pager.On("PageResponse").Return(armnetwork.LoadBalancerLoadBalancingRulesListResponse{ - LoadBalancerLoadBalancingRulesListResult: armnetwork.LoadBalancerLoadBalancingRulesListResult{ - LoadBalancerLoadBalancingRuleListResult: armnetwork.LoadBalancerLoadBalancingRuleListResult{ - Value: expectedResults, - }, - }, - }).Times(1) - pager.On("Err").Return(nil).Times(2) - - mockCache.On("Get", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress").Return(nil).Times(1) - mockCache.On("Put", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return load balancers", - loadBalancer: &armnetwork.LoadBalancer{ - Resource: armnetwork.Resource{ID: to.StringPtr("lb-1")}, - }, - mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache.MockCache) { - mockCache.On("Get", "networkListLoadBalancerRules_lb-1").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - loadBalancer: &armnetwork.LoadBalancer{ - Resource: armnetwork.Resource{ID: to.StringPtr("/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress")}, - }, - mocks: func(client *mockLoadBalancerRulesClient, pager *mockLoadBalancerRulesListAllPager, mockCache *cache.MockCache) { - client.On("List", "driftctl", "PublicIPAddress", &armnetwork.LoadBalancerLoadBalancingRulesListOptions{}).Return(pager) - - pager.On("NextPage", context.Background()).Return(true).Times(1) - pager.On("NextPage", context.Background()).Return(false).Times(1) - pager.On("PageResponse").Return(armnetwork.LoadBalancerLoadBalancingRulesListResponse{}).Times(1) - pager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "networkListLoadBalancerRules_/subscriptions/xxx/resourceGroups/driftctl/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakePager := &mockLoadBalancerRulesListAllPager{} - fakeClient := &mockLoadBalancerRulesClient{} - mockCache := &cache.MockCache{} - - tt.mocks(fakeClient, fakePager, mockCache) - - s := &networkRepository{ - loadBalancerRulesClient: fakeClient, - cache: mockCache, - } - got, err := s.ListLoadBalancerRules(tt.loadBalancer) - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllLoadBalancers() got = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/pkg/remote/azurerm/repository/postgresql.go b/pkg/remote/azurerm/repository/postgresql.go deleted file mode 100644 index 644704ec..00000000 --- a/pkg/remote/azurerm/repository/postgresql.go +++ /dev/null @@ -1,93 +0,0 @@ -package repository - -import ( - "context" - "fmt" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type PostgresqlRespository interface { - ListAllServers() ([]*armpostgresql.Server, error) - ListAllDatabasesByServer(server *armpostgresql.Server) ([]*armpostgresql.Database, error) -} - -type postgresqlServersClientImpl struct { - client *armpostgresql.ServersClient -} - -type postgresqlServersClient interface { - List(context.Context, *armpostgresql.ServersListOptions) (armpostgresql.ServersListResponse, error) -} - -func (c postgresqlServersClientImpl) List(ctx context.Context, options *armpostgresql.ServersListOptions) (armpostgresql.ServersListResponse, error) { - return c.client.List(ctx, options) -} - -type postgresqlDatabaseClientImpl struct { - client *armpostgresql.DatabasesClient -} - -type postgresqlDatabaseClient interface { - ListByServer(context.Context, string, string, *armpostgresql.DatabasesListByServerOptions) (armpostgresql.DatabasesListByServerResponse, error) -} - -func (c postgresqlDatabaseClientImpl) ListByServer(ctx context.Context, resGroup string, serverName string, options *armpostgresql.DatabasesListByServerOptions) (armpostgresql.DatabasesListByServerResponse, error) { - return c.client.ListByServer(ctx, resGroup, serverName, options) -} - -type postgresqlRepository struct { - serversClient postgresqlServersClient - databaseClient postgresqlDatabaseClient - cache cache.Cache -} - -func NewPostgresqlRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *postgresqlRepository { - return &postgresqlRepository{ - postgresqlServersClientImpl{client: armpostgresql.NewServersClient(config.SubscriptionID, cred, options)}, - postgresqlDatabaseClientImpl{client: armpostgresql.NewDatabasesClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *postgresqlRepository) ListAllServers() ([]*armpostgresql.Server, error) { - cacheKey := "postgresqlListAllServers" - - defer s.cache.Unlock(cacheKey) - if v := s.cache.GetAndLock(cacheKey); v != nil { - return v.([]*armpostgresql.Server), nil - } - - res, err := s.serversClient.List(context.Background(), nil) - if err != nil { - return nil, err - } - - s.cache.Put(cacheKey, res.Value) - return res.Value, nil -} - -func (s *postgresqlRepository) ListAllDatabasesByServer(server *armpostgresql.Server) ([]*armpostgresql.Database, error) { - res, err := azure.ParseResourceID(*server.ID) - if err != nil { - return nil, err - } - - cacheKey := fmt.Sprintf("postgresqlListAllDatabases_%s_%s", res.ResourceGroup, *server.Name) - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armpostgresql.Database), nil - } - - result, err := s.databaseClient.ListByServer(context.Background(), res.ResourceGroup, *server.Name, nil) - if err != nil { - return nil, err - } - - s.cache.Put(cacheKey, result.Value) - return result.Value, nil -} diff --git a/pkg/remote/azurerm/repository/postgresql_test.go b/pkg/remote/azurerm/repository/postgresql_test.go deleted file mode 100644 index d88e0934..00000000 --- a/pkg/remote/azurerm/repository/postgresql_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package repository - -import ( - "context" - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_Postgresql_ListAllServers(t *testing.T) { - expectedResults := []*armpostgresql.Server{ - { - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("postgresql-server-1"), - }, - }, - }, - { - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("postgresql-server-2"), - }, - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockPostgresqlServersClient, *cache.MockCache) - expected []*armpostgresql.Server - wantErr string - }{ - { - name: "should return postgres servers", - mocks: func(client *mockPostgresqlServersClient, mockCache *cache.MockCache) { - client.On("List", context.Background(), mock.Anything).Return(armpostgresql.ServersListResponse{ - ServersListResult: armpostgresql.ServersListResult{ - ServerListResult: armpostgresql.ServerListResult{ - Value: expectedResults, - }, - }, - }, nil).Times(1) - - mockCache.On("GetAndLock", "postgresqlListAllServers").Return(nil).Times(1) - mockCache.On("Unlock", "postgresqlListAllServers").Return().Times(1) - mockCache.On("Put", "postgresqlListAllServers", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return postgres servers", - mocks: func(client *mockPostgresqlServersClient, mockCache *cache.MockCache) { - mockCache.On("GetAndLock", "postgresqlListAllServers").Return(expectedResults).Times(1) - mockCache.On("Unlock", "postgresqlListAllServers").Return().Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(client *mockPostgresqlServersClient, mockCache *cache.MockCache) { - client.On("List", context.Background(), mock.Anything).Return(armpostgresql.ServersListResponse{}, errors.New("remote error")).Times(1) - - mockCache.On("GetAndLock", "postgresqlListAllServers").Return(nil).Times(1) - mockCache.On("Unlock", "postgresqlListAllServers").Return().Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakeClient := &mockPostgresqlServersClient{} - mockCache := &cache.MockCache{} - - tt.mocks(fakeClient, mockCache) - - s := &postgresqlRepository{ - serversClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllServers() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} - -func Test_Postgresql_ListAllDatabases(t *testing.T) { - expectedResults := []*armpostgresql.Database{ - { - ProxyResource: armpostgresql.ProxyResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/res-group/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-1/databases/postgresql-db-1"), - }, - }, - }, - { - ProxyResource: armpostgresql.ProxyResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/res-group/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-1/databases/postgresql-db-2"), - }, - }, - }, - } - - testcases := []struct { - name string - mocks func(*mockPostgresqlDatabaseClient, *cache.MockCache) - expected []*armpostgresql.Database - wantErr string - }{ - { - name: "should return postgres servers", - mocks: func(client *mockPostgresqlDatabaseClient, mockCache *cache.MockCache) { - client.On("ListByServer", context.Background(), "res-group", "postgresql-server-1", (*armpostgresql.DatabasesListByServerOptions)(nil)).Return(armpostgresql.DatabasesListByServerResponse{ - DatabasesListByServerResult: armpostgresql.DatabasesListByServerResult{ - DatabaseListResult: armpostgresql.DatabaseListResult{ - Value: expectedResults, - }, - }, - }, nil).Times(1) - - mockCache.On("Get", "postgresqlListAllDatabases_res-group_postgresql-server-1").Return(nil).Times(1) - mockCache.On("Put", "postgresqlListAllDatabases_res-group_postgresql-server-1", expectedResults).Return(false).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return postgres servers", - mocks: func(client *mockPostgresqlDatabaseClient, mockCache *cache.MockCache) { - mockCache.On("Get", "postgresqlListAllDatabases_res-group_postgresql-server-1").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(client *mockPostgresqlDatabaseClient, mockCache *cache.MockCache) { - mockCache.On("Get", "postgresqlListAllDatabases_res-group_postgresql-server-1").Return(nil).Times(1) - - client.On("ListByServer", context.Background(), "res-group", "postgresql-server-1", (*armpostgresql.DatabasesListByServerOptions)(nil)).Return(armpostgresql.DatabasesListByServerResponse{}, errors.New("remote error")).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakeClient := &mockPostgresqlDatabaseClient{} - mockCache := &cache.MockCache{} - - tt.mocks(fakeClient, mockCache) - - s := &postgresqlRepository{ - databaseClient: fakeClient, - cache: mockCache, - } - got, err := s.ListAllDatabasesByServer(&armpostgresql.Server{ - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/res-group/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-1"), - Name: to.StringPtr("postgresql-server-1"), - }, - }, - }) - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/pkg/remote/azurerm/repository/privatedns.go b/pkg/remote/azurerm/repository/privatedns.go deleted file mode 100644 index a2a1fa4a..00000000 --- a/pkg/remote/azurerm/repository/privatedns.go +++ /dev/null @@ -1,243 +0,0 @@ -package repository - -import ( - "context" - "fmt" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type PrivateDNSRepository interface { - ListAllPrivateZones() ([]*armprivatedns.PrivateZone, error) - ListAllARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) - ListAllAAAARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) - ListAllCNAMERecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) - ListAllPTRRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) - ListAllMXRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) - ListAllSRVRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) - ListAllTXTRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) -} - -type privateDNSZoneListPager interface { - pager - PageResponse() armprivatedns.PrivateZonesListResponse -} - -type privateDNSRecordSetListPager interface { - pager - PageResponse() armprivatedns.RecordSetsListResponse -} - -type privateRecordSetClient interface { - List(resourceGroupName string, privateZoneName string, options *armprivatedns.RecordSetsListOptions) privateDNSRecordSetListPager -} - -type privateRecordSetClientImpl struct { - client *armprivatedns.RecordSetsClient -} - -func (c *privateRecordSetClientImpl) List(resourceGroupName string, privateZoneName string, options *armprivatedns.RecordSetsListOptions) privateDNSRecordSetListPager { - return c.client.List(resourceGroupName, privateZoneName, options) -} - -type privateZonesClient interface { - List(options *armprivatedns.PrivateZonesListOptions) privateDNSZoneListPager -} - -type privateZonesClientImpl struct { - client *armprivatedns.PrivateZonesClient -} - -func (c *privateZonesClientImpl) List(options *armprivatedns.PrivateZonesListOptions) privateDNSZoneListPager { - return c.client.List(options) -} - -type privateDNSRepository struct { - zoneClient privateZonesClient - recordClient privateRecordSetClient - cache cache.Cache -} - -func NewPrivateDNSRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *privateDNSRepository { - return &privateDNSRepository{ - &privateZonesClientImpl{armprivatedns.NewPrivateZonesClient(config.SubscriptionID, cred, options)}, - &privateRecordSetClientImpl{armprivatedns.NewRecordSetsClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *privateDNSRepository) listAllRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - cacheKey := fmt.Sprintf("privateDNSlistAllRecords-%s", *zone.ID) - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.([]*armprivatedns.RecordSet), nil - } - - res, err := azure.ParseResourceID(*zone.ID) - if err != nil { - return nil, err - } - - pager := s.recordClient.List(res.ResourceGroup, *zone.Name, nil) - results := make([]*armprivatedns.RecordSet, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *privateDNSRepository) ListAllARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.ARecords == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllAAAARecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.AaaaRecords == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllPTRRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.PtrRecords == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllCNAMERecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.CnameRecord == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllMXRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.MxRecords == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllSRVRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.SrvRecords == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllTXTRecords(zone *armprivatedns.PrivateZone) ([]*armprivatedns.RecordSet, error) { - records, err := s.listAllRecords(zone) - if err != nil { - return nil, err - } - results := make([]*armprivatedns.RecordSet, 0) - for _, record := range records { - if record.Properties.TxtRecords == nil { - continue - } - results = append(results, record) - - } - return results, nil -} - -func (s *privateDNSRepository) ListAllPrivateZones() ([]*armprivatedns.PrivateZone, error) { - cacheKey := "privateDNSListAllPrivateZones" - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.([]*armprivatedns.PrivateZone), nil - } - - pager := s.zoneClient.List(nil) - results := make([]*armprivatedns.PrivateZone, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} diff --git a/pkg/remote/azurerm/repository/privatedns_test.go b/pkg/remote/azurerm/repository/privatedns_test.go deleted file mode 100644 index 0bcdbbbe..00000000 --- a/pkg/remote/azurerm/repository/privatedns_test.go +++ /dev/null @@ -1,1639 +0,0 @@ -package repository - -import ( - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -// region PrivateZone -func Test_ListAllPrivateZones_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone1"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone2"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone3"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone4"), - }, - }, - }, - } - - fakeClient := &mockPrivateZonesClient{} - - mockPager := &mockPrivateDNSZoneListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.PrivateZonesListResponse{ - PrivateZonesListResult: armprivatedns.PrivateZonesListResult{ - PrivateZoneListResult: armprivatedns.PrivateZoneListResult{ - Value: []*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone1"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone2"), - }, - }, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.PrivateZonesListResponse{ - PrivateZonesListResult: armprivatedns.PrivateZonesListResult{ - PrivateZoneListResult: armprivatedns.PrivateZoneListResult{ - Value: []*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone3"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone4"), - }, - }, - }, - }, - }, - }, - }).Times(1) - - fakeClient.On("List", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSListAllPrivateZones").Return(nil).Times(1) - c.On("Unlock", "privateDNSListAllPrivateZones").Times(1) - c.On("Put", "privateDNSListAllPrivateZones", expected).Return(true).Times(1) - s := &privateDNSRepository{ - zoneClient: fakeClient, - cache: c, - } - got, err := s.ListAllPrivateZones() - if err != nil { - t.Errorf("ListAllPrivateZones() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllPrivateZones() got = %v, want %v", got, expected) - } -} - -func Test_ListAllPrivateZones_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("zone1"), - }, - }, - }, - } - - fakeClient := &mockPrivateZonesClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSListAllPrivateZones").Return(expected).Times(1) - c.On("Unlock", "privateDNSListAllPrivateZones").Times(1) - - s := &privateDNSRepository{ - zoneClient: fakeClient, - cache: c, - } - got, err := s.ListAllPrivateZones() - if err != nil { - t.Errorf("ListAllPrivateZones() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllPrivateZones() got = %v, want %v", got, expected) - } -} - -func Test_ListAllPrivateZones_Error(t *testing.T) { - - fakeClient := &mockPrivateZonesClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSZoneListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.PrivateZonesListResponse{}).Times(1) - - fakeClient.On("List", mock.Anything).Return(mockPager) - - s := &privateDNSRepository{ - zoneClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllPrivateZones() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region ARecord -func Test_ListAllARecords_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("ip")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("ip")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("ip")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("ip")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllARecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllARecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllARecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllARecords_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("ip")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllARecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllARecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllARecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllARecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllARecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region AAAAAAARecord -func Test_ListAllAAAARecords_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("ip")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("ip")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("ip")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("ip")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllAAAARecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllAAAARecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllAAAARecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllAAAARecords_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("ip")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllAAAARecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllAAAARecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllAAAARecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllAAAARecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllAAAARecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region CNAMERecord -func Test_ListAllCNAMERecords_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - CnameRecord: &armprivatedns.CnameRecord{ - Cname: to.StringPtr("cname"), - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - CnameRecord: &armprivatedns.CnameRecord{ - Cname: to.StringPtr("cname"), - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - CnameRecord: &armprivatedns.CnameRecord{ - Cname: to.StringPtr("cname"), - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - CnameRecord: &armprivatedns.CnameRecord{ - Cname: to.StringPtr("cname"), - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllCNAMERecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllCNAMERecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllCNAMERecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllCNAMERecords_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - CnameRecord: &armprivatedns.CnameRecord{ - Cname: to.StringPtr("cname"), - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllCNAMERecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllCNAMERecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllCNAMERecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllCNAMERecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllCNAMERecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region PTRRecord -func Test_ListAllPTRRecords_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptrdname")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptrdname")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptrdname")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptrdname")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllPTRRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllPTRRecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllPTRRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllPTRRecords_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptrdname")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllPTRRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllPTRRecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllPTRRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllPTRRecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllPTRRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region MXRecord -func Test_ListAllMXRecords_MultiplesResults(t *testing.T) { - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllMXRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllMXRecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllMXRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllMXRecords_MultiplesResults_WithCache(t *testing.T) { - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - - got, err := s.ListAllMXRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - - t.Errorf("ListAllMXRecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllMXRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllMXRecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - - got, err := s.ListAllMXRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region SRVRecord -func Test_ListAllSRVRecords_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - SrvRecords: []*armprivatedns.SrvRecord{ - {Target: to.StringPtr("targetname")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - SrvRecords: []*armprivatedns.SrvRecord{ - {Target: to.StringPtr("targetname")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - SrvRecords: []*armprivatedns.SrvRecord{ - {Target: to.StringPtr("targetname")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - SrvRecords: []*armprivatedns.SrvRecord{ - {Target: to.StringPtr("targetname")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllSRVRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllSRVRecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllSRVRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllSRVRecords_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - SrvRecords: []*armprivatedns.SrvRecord{ - {Target: to.StringPtr("targetname")}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllSRVRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllSRVRecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllSRVRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllSRVRecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllSRVRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion - -// region TXTRecord -func Test_ListAllTXTRecords_MultiplesResults(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - TxtRecords: []*armprivatedns.TxtRecord{ - {Value: []*string{to.StringPtr("value")}}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - TxtRecords: []*armprivatedns.TxtRecord{ - {Value: []*string{to.StringPtr("value")}}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - TxtRecords: []*armprivatedns.TxtRecord{ - {Value: []*string{to.StringPtr("value")}}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record2"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{ - RecordSetsListResult: armprivatedns.RecordSetsListResult{ - RecordSetListResult: armprivatedns.RecordSetListResult{ - Value: []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record3"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - TxtRecords: []*armprivatedns.TxtRecord{ - {Value: []*string{to.StringPtr("value")}}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record4"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{}, - }, - }, - }, - }, - }).Times(1) - - fakeRecordSetClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(nil).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return().Times(1) - c.On("Put", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com", mock.Anything).Return(true).Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllTXTRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllTXTRecords() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllTXTRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllTXTRecords_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("record1"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - TxtRecords: []*armprivatedns.TxtRecord{ - {Value: []*string{to.StringPtr("value")}}, - }, - }, - }, - } - - fakeRecordSetClient := &mockPrivateRecordSetClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Return(expected).Times(1) - c.On("Unlock", "privateDNSlistAllRecords-/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com").Times(1) - s := &privateDNSRepository{ - recordClient: fakeRecordSetClient, - cache: c, - } - got, err := s.ListAllTXTRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - if err != nil { - t.Errorf("ListAllTXTRecords() error = %v", err) - return - } - - fakeRecordSetClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllTXTRecords() got = %v, want %v", got, expected) - } -} - -func Test_ListAllTXTRecords_Error(t *testing.T) { - - fakeClient := &mockPrivateRecordSetClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockPrivateDNSRecordSetListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armprivatedns.RecordSetsListResponse{}).Times(1) - - fakeClient.On("List", "rgid", "zone", (*armprivatedns.RecordSetsListOptions)(nil)).Return(mockPager) - - s := &privateDNSRepository{ - recordClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllTXTRecords(&armprivatedns.PrivateZone{ - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/subid/resourceGroups/rgid/providers/Microsoft.Network/privateDnsZones/zone.com"), - Name: to.StringPtr("zone"), - }, - }, - }) - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -// endregion diff --git a/pkg/remote/azurerm/repository/resources.go b/pkg/remote/azurerm/repository/resources.go deleted file mode 100644 index c7ba9bd0..00000000 --- a/pkg/remote/azurerm/repository/resources.go +++ /dev/null @@ -1,68 +0,0 @@ -package repository - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type ResourcesRepository interface { - ListAllResourceGroups() ([]*armresources.ResourceGroup, error) -} - -type resourcesListPager interface { - pager - PageResponse() armresources.ResourceGroupsListResponse -} - -type resourcesClient interface { - List(options *armresources.ResourceGroupsListOptions) resourcesListPager -} - -type resourcesClientImpl struct { - client *armresources.ResourceGroupsClient -} - -func (c resourcesClientImpl) List(options *armresources.ResourceGroupsListOptions) resourcesListPager { - return c.client.List(options) -} - -type resourcesRepository struct { - client resourcesClient - cache cache.Cache -} - -func NewResourcesRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *resourcesRepository { - return &resourcesRepository{ - &resourcesClientImpl{armresources.NewResourceGroupsClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *resourcesRepository) ListAllResourceGroups() ([]*armresources.ResourceGroup, error) { - cacheKey := "resourcesListAllResourceGroups" - if v := s.cache.Get(cacheKey); v != nil { - return v.([]*armresources.ResourceGroup), nil - } - - pager := s.client.List(nil) - results := make([]*armresources.ResourceGroup, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.ResourceGroupsListResult.Value...) - } - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} diff --git a/pkg/remote/azurerm/repository/resources_test.go b/pkg/remote/azurerm/repository/resources_test.go deleted file mode 100644 index 841ee839..00000000 --- a/pkg/remote/azurerm/repository/resources_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package repository - -import ( - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_Resources_ListAllResourceGroups(t *testing.T) { - expectedResults := []*armresources.ResourceGroup{ - { - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/elie-dev"), - Name: to.StringPtr("elie-dev"), - }, - { - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/william-dev"), - Name: to.StringPtr("william-dev"), - }, - { - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/driftctl-sj-tests"), - Name: to.StringPtr("driftctl-sj-tests"), - }, - } - - testcases := []struct { - name string - mocks func(*mockResourcesListPager, *cache.MockCache) - expected []*armresources.ResourceGroup - wantErr string - }{ - { - name: "should return resource groups", - mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ - ResourceGroupsListResult: armresources.ResourceGroupsListResult{ - ResourceGroupListResult: armresources.ResourceGroupListResult{ - Value: []*armresources.ResourceGroup{ - { - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/elie-dev"), - Name: to.StringPtr("elie-dev"), - }, - { - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/william-dev"), - Name: to.StringPtr("william-dev"), - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ - ResourceGroupsListResult: armresources.ResourceGroupsListResult{ - ResourceGroupListResult: armresources.ResourceGroupListResult{ - Value: []*armresources.ResourceGroup{ - { - ID: to.StringPtr("/subscriptions/008b5f48-1b66-4d92-a6b6-d215b4c9b473/resourceGroups/driftctl-sj-tests"), - Name: to.StringPtr("driftctl-sj-tests"), - }, - }, - }, - }, - }).Times(1) - - mockCache.On("Get", "resourcesListAllResourceGroups").Return(nil).Times(1) - mockCache.On("Put", "resourcesListAllResourceGroups", expectedResults).Return(true).Times(1) - }, - expected: expectedResults, - }, - { - name: "should hit cache and return resource groups", - mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { - mockCache.On("Get", "resourcesListAllResourceGroups").Return(expectedResults).Times(1) - }, - expected: expectedResults, - }, - { - name: "should return remote error", - mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ - ResourceGroupsListResult: armresources.ResourceGroupsListResult{ - ResourceGroupListResult: armresources.ResourceGroupListResult{ - Value: []*armresources.ResourceGroup{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "resourcesListAllResourceGroups").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - { - name: "should return remote error after fetching all pages", - mocks: func(mockPager *mockResourcesListPager, mockCache *cache.MockCache) { - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armresources.ResourceGroupsListResponse{ - ResourceGroupsListResult: armresources.ResourceGroupsListResult{ - ResourceGroupListResult: armresources.ResourceGroupListResult{ - Value: []*armresources.ResourceGroup{}, - }, - }, - }).Times(1) - mockPager.On("Err").Return(nil).Times(1) - mockPager.On("Err").Return(errors.New("remote error")).Times(1) - - mockCache.On("Get", "resourcesListAllResourceGroups").Return(nil).Times(1) - }, - wantErr: "remote error", - }, - } - - for _, tt := range testcases { - t.Run(tt.name, func(t *testing.T) { - fakeClient := &mockResourcesClient{} - mockPager := &mockResourcesListPager{} - mockCache := &cache.MockCache{} - - fakeClient.On("List", mock.Anything).Maybe().Return(mockPager) - - tt.mocks(mockPager, mockCache) - - s := &resourcesRepository{ - client: fakeClient, - cache: mockCache, - } - got, err := s.ListAllResourceGroups() - if tt.wantErr != "" { - assert.EqualError(t, err, tt.wantErr) - } else { - assert.Nil(t, err) - } - - fakeClient.AssertExpectations(t) - mockPager.AssertExpectations(t) - mockCache.AssertExpectations(t) - - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("ListAllResourceGroups() got = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/pkg/remote/azurerm/repository/storage.go b/pkg/remote/azurerm/repository/storage.go deleted file mode 100644 index 11791ff6..00000000 --- a/pkg/remote/azurerm/repository/storage.go +++ /dev/null @@ -1,128 +0,0 @@ -package repository - -import ( - "context" - "fmt" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/snyk/driftctl/pkg/remote/azurerm/common" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type StorageRespository interface { - ListAllStorageAccount() ([]*armstorage.StorageAccount, error) - ListAllStorageContainer(account *armstorage.StorageAccount) ([]string, error) -} - -type blobContainerListPager interface { - pager - PageResponse() armstorage.BlobContainersListResponse -} - -// Interfaces are only used to create mock on Azure SDK -type blobContainerClient interface { - List(resourceGroupName string, accountName string, options *armstorage.BlobContainersListOptions) blobContainerListPager -} - -type blobContainerClientImpl struct { - client *armstorage.BlobContainersClient -} - -func (c blobContainerClientImpl) List(resourceGroupName string, accountName string, options *armstorage.BlobContainersListOptions) blobContainerListPager { - return c.client.List(resourceGroupName, accountName, options) -} - -type storageAccountListPager interface { - pager - PageResponse() armstorage.StorageAccountsListResponse -} - -type storageAccountClient interface { - List(options *armstorage.StorageAccountsListOptions) storageAccountListPager -} - -type storageAccountClientImpl struct { - client *armstorage.StorageAccountsClient -} - -func (c storageAccountClientImpl) List(options *armstorage.StorageAccountsListOptions) storageAccountListPager { - return c.client.List(options) -} - -type storageRepository struct { - storageAccountsClient storageAccountClient - blobContainerClient blobContainerClient - cache cache.Cache -} - -func NewStorageRepository(cred azcore.TokenCredential, options *arm.ClientOptions, config common.AzureProviderConfig, cache cache.Cache) *storageRepository { - return &storageRepository{ - storageAccountClientImpl{client: armstorage.NewStorageAccountsClient(config.SubscriptionID, cred, options)}, - blobContainerClientImpl{client: armstorage.NewBlobContainersClient(config.SubscriptionID, cred, options)}, - cache, - } -} - -func (s *storageRepository) ListAllStorageAccount() ([]*armstorage.StorageAccount, error) { - - cacheKey := "ListAllStorageAccount" - v := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if v != nil { - return v.([]*armstorage.StorageAccount), nil - } - - pager := s.storageAccountsClient.List(nil) - results := make([]*armstorage.StorageAccount, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - results = append(results, resp.StorageAccountsListResult.StorageAccountListResult.Value...) - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} - -func (s *storageRepository) ListAllStorageContainer(account *armstorage.StorageAccount) ([]string, error) { - - cacheKey := fmt.Sprintf("ListAllStorageContainer_%s", *account.Name) - if v := s.cache.Get(cacheKey); v != nil { - return v.([]string), nil - } - - res, err := azure.ParseResourceID(*account.ID) - if err != nil { - return nil, err - } - - pager := s.blobContainerClient.List(res.ResourceGroup, *account.Name, nil) - results := make([]string, 0) - for pager.NextPage(context.Background()) { - resp := pager.PageResponse() - if err := pager.Err(); err != nil { - return nil, err - } - for _, item := range resp.BlobContainersListResult.ListContainerItems.Value { - results = append(results, fmt.Sprintf("%s%s", *account.Properties.PrimaryEndpoints.Blob, *item.Name)) - } - } - - if err := pager.Err(); err != nil { - return nil, err - } - - s.cache.Put(cacheKey, results) - - return results, nil -} diff --git a/pkg/remote/azurerm/repository/storage_test.go b/pkg/remote/azurerm/repository/storage_test.go deleted file mode 100644 index b9dc424c..00000000 --- a/pkg/remote/azurerm/repository/storage_test.go +++ /dev/null @@ -1,373 +0,0 @@ -package repository - -import ( - "reflect" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func Test_ListAllStorageAccount_MultiplesResults(t *testing.T) { - - expected := []*armstorage.StorageAccount{ - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account1"), - }, - }, - }, - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account2"), - }, - }, - }, - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account3"), - }, - }, - }, - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account4"), - }, - }, - }, - } - - fakeClient := &mockStorageAccountClient{} - - mockPager := &mockStorageAccountListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armstorage.StorageAccountsListResponse{ - StorageAccountsListResult: armstorage.StorageAccountsListResult{ - StorageAccountListResult: armstorage.StorageAccountListResult{ - Value: []*armstorage.StorageAccount{ - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account1"), - }, - }, - }, - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account2"), - }, - }, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armstorage.StorageAccountsListResponse{ - StorageAccountsListResult: armstorage.StorageAccountsListResult{ - StorageAccountListResult: armstorage.StorageAccountListResult{ - Value: []*armstorage.StorageAccount{ - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account3"), - }, - }, - }, - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account4"), - }, - }, - }, - }, - }, - }, - }).Times(1) - - fakeClient.On("List", mock.Anything).Return(mockPager) - - c := &cache.MockCache{} - c.On("GetAndLock", "ListAllStorageAccount").Return(nil).Times(1) - c.On("Unlock", "ListAllStorageAccount").Times(1) - c.On("Put", "ListAllStorageAccount", expected).Return(true).Times(1) - s := &storageRepository{ - storageAccountsClient: fakeClient, - cache: c, - } - got, err := s.ListAllStorageAccount() - if err != nil { - t.Errorf("ListAllStorageAccount() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllStorageAccount() got = %v, want %v", got, expected) - } -} - -func Test_ListAllStorageAccount_MultiplesResults_WithCache(t *testing.T) { - - expected := []*armstorage.StorageAccount{ - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("account1"), - }, - }, - }, - } - - fakeClient := &mockStorageAccountClient{} - - c := &cache.MockCache{} - c.On("GetAndLock", "ListAllStorageAccount").Return(expected).Times(1) - c.On("Unlock", "ListAllStorageAccount").Times(1) - s := &storageRepository{ - storageAccountsClient: fakeClient, - cache: c, - } - got, err := s.ListAllStorageAccount() - if err != nil { - t.Errorf("ListAllStorageAccount() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - if !reflect.DeepEqual(got, expected) { - t.Errorf("ListAllStorageAccount() got = %v, want %v", got, expected) - } -} - -func Test_ListAllStorageAccount_Error(t *testing.T) { - - fakeClient := &mockStorageAccountClient{} - - expectedErr := errors.New("unexpected error") - - mockPager := &mockStorageAccountListPager{} - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("PageResponse").Return(armstorage.StorageAccountsListResponse{}).Times(1) - - fakeClient.On("List", mock.Anything).Return(mockPager) - - s := &storageRepository{ - storageAccountsClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllStorageAccount() - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - - assert.Equal(t, expectedErr, err) - assert.Nil(t, got) -} - -func Test_ListAllStorageContainer_MultiplesResults(t *testing.T) { - - account := armstorage.StorageAccount{ - Properties: &armstorage.StorageAccountProperties{ - PrimaryEndpoints: &armstorage.Endpoints{ - Blob: to.StringPtr("https://testeliedriftctl.blob.core.windows.net/"), - }, - }, - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/foobar/providers/Microsoft.Storage/storageAccounts/testeliedriftctl"), - Name: to.StringPtr("testeliedriftctl"), - }, - }, - } - - expected := []string{ - "https://testeliedriftctl.blob.core.windows.net/container1", - "https://testeliedriftctl.blob.core.windows.net/container2", - "https://testeliedriftctl.blob.core.windows.net/container3", - "https://testeliedriftctl.blob.core.windows.net/container4", - } - - fakeClient := &mockBlobContainerClient{} - - mockPager := &mockBlobContainerListPager{} - mockPager.On("Err").Return(nil).Times(3) - mockPager.On("NextPage", mock.Anything).Return(true).Times(2) - mockPager.On("NextPage", mock.Anything).Return(false).Times(1) - mockPager.On("PageResponse").Return(armstorage.BlobContainersListResponse{ - BlobContainersListResult: armstorage.BlobContainersListResult{ - ListContainerItems: armstorage.ListContainerItems{ - Value: []*armstorage.ListContainerItem{ - { - AzureEntityResource: armstorage.AzureEntityResource{ - Resource: armstorage.Resource{ - Name: to.StringPtr("container1"), - }, - }, - }, - { - AzureEntityResource: armstorage.AzureEntityResource{ - Resource: armstorage.Resource{ - Name: to.StringPtr("container2"), - }, - }, - }, - }, - }, - }, - }).Times(1) - mockPager.On("PageResponse").Return(armstorage.BlobContainersListResponse{ - BlobContainersListResult: armstorage.BlobContainersListResult{ - ListContainerItems: armstorage.ListContainerItems{ - Value: []*armstorage.ListContainerItem{ - { - AzureEntityResource: armstorage.AzureEntityResource{ - Resource: armstorage.Resource{ - Name: to.StringPtr("container3"), - }, - }, - }, - { - AzureEntityResource: armstorage.AzureEntityResource{ - Resource: armstorage.Resource{ - Name: to.StringPtr("container4"), - }, - }, - }, - }, - }, - }, - }).Times(1) - - fakeClient.On("List", "foobar", "testeliedriftctl", (*armstorage.BlobContainersListOptions)(nil)).Return(mockPager) - - c := &cache.MockCache{} - c.On("Get", "ListAllStorageContainer_testeliedriftctl").Return(nil).Times(1) - c.On("Put", "ListAllStorageContainer_testeliedriftctl", expected).Return(true).Times(1) - s := &storageRepository{ - blobContainerClient: fakeClient, - cache: c, - } - got, err := s.ListAllStorageContainer(&account) - if err != nil { - t.Errorf("ListAllStorageAccount() error = %v", err) - return - } - - mockPager.AssertExpectations(t) - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - assert.Equal(t, expected, got) -} - -func Test_ListAllStorageContainer_MultiplesResults_WithCache(t *testing.T) { - - account := armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/foobar/providers/Microsoft.Storage/storageAccounts/testeliedriftctl"), - Name: to.StringPtr("testeliedriftctl"), - }, - }, - } - - expected := []string{ - "https://testeliedriftctl.blob.core.windows.net/container1", - } - - fakeClient := &mockBlobContainerClient{} - - c := &cache.MockCache{} - c.On("Get", "ListAllStorageContainer_testeliedriftctl").Return(expected).Times(1) - s := &storageRepository{ - blobContainerClient: fakeClient, - cache: c, - } - got, err := s.ListAllStorageContainer(&account) - if err != nil { - t.Errorf("ListAllStorageAccount() error = %v", err) - return - } - - fakeClient.AssertExpectations(t) - c.AssertExpectations(t) - - assert.Equal(t, expected, got) -} - -func Test_ListAllStorageContainer_InvalidStorageAccountResourceID(t *testing.T) { - - account := armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: to.StringPtr("foobar"), - Name: to.StringPtr(""), - }, - }, - } - - fakeClient := &mockBlobContainerClient{} - - s := &storageRepository{ - blobContainerClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllStorageContainer(&account) - - fakeClient.AssertExpectations(t) - - assert.Nil(t, got) - assert.Equal(t, "parsing failed for foobar. Invalid resource Id format", err.Error()) -} - -func Test_ListAllStorageContainer_Error(t *testing.T) { - - account := armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: to.StringPtr("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/foobar/providers/Microsoft.Storage/storageAccounts/testeliedriftctl"), - Name: to.StringPtr("testeliedriftctl"), - }, - }, - } - - expectedErr := errors.New("sample error") - - fakeClient := &mockBlobContainerClient{} - mockPager := &mockBlobContainerListPager{} - mockPager.On("NextPage", mock.Anything).Return(true).Times(1) - mockPager.On("Err").Return(expectedErr).Times(1) - mockPager.On("PageResponse").Return(armstorage.BlobContainersListResponse{}).Times(1) - - fakeClient.On("List", "foobar", "testeliedriftctl", (*armstorage.BlobContainersListOptions)(nil)).Return(mockPager) - - s := &storageRepository{ - blobContainerClient: fakeClient, - cache: cache.New(0), - } - got, err := s.ListAllStorageContainer(&account) - - fakeClient.AssertExpectations(t) - mockPager.AssertExpectations(t) - - assert.Nil(t, got) - assert.Equal(t, expectedErr, err) -} diff --git a/pkg/remote/azurerm_compute_scanner_test.go b/pkg/remote/azurerm_compute_scanner_test.go deleted file mode 100644 index 1aa86a72..00000000 --- a/pkg/remote/azurerm_compute_scanner_test.go +++ /dev/null @@ -1,233 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermCompute_Image(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockComputeRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no images", - mocks: func(repository *repository.MockComputeRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllImages").Return([]*armcompute.Image{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing images", - mocks: func(repository *repository.MockComputeRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllImages").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzureImageResourceType), - }, - { - test: "multiple images including an invalid ID", - mocks: func(repository *repository.MockComputeRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllImages").Return([]*armcompute.Image{ - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image1"), - Name: to.StringPtr("image1"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image2"), - Name: to.StringPtr("image2"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/invalid-id/image3"), - Name: to.StringPtr("image3"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureImageResourceType) - - assert.Equal(t, got[1].ResourceId(), "/subscriptions/4e411884-65b0-4911-bc80-52f9a21942a2/resourceGroups/testgroup/providers/Microsoft.Compute/images/image2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureImageResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockComputeRepository{} - c.mocks(fakeRepo, alerter) - - remoteLibrary.AddEnumerator(azurerm.NewAzurermImageEnumerator(fakeRepo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermCompute_SSHPublicKey(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockComputeRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no public key", - dirName: "azurerm_ssh_public_key_empty", - mocks: func(repository *repository.MockComputeRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSSHPublicKeys").Return([]*armcompute.SSHPublicKeyResource{}, nil) - }, - }, - { - test: "error listing public keys", - dirName: "azurerm_ssh_public_key_empty", - mocks: func(repository *repository.MockComputeRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSSHPublicKeys").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzureSSHPublicKeyResourceType), - }, - { - test: "multiple public keys", - dirName: "azurerm_ssh_public_key_multiple", - mocks: func(repository *repository.MockComputeRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSSHPublicKeys").Return([]*armcompute.SSHPublicKeyResource{ - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key"), - Name: to.StringPtr("example-key"), - }, - }, - { - Resource: armcompute.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key2"), - Name: to.StringPtr("example-key2"), - }, - }, - }, nil) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockComputeRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ComputeRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraform2.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewComputeRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermSSHPublicKeyEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzureSSHPublicKeyResourceType, common.NewGenericDetailsFetcher(resourceazure.AzureSSHPublicKeyResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzureSSHPublicKeyResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm_containerregistry_scanner_test.go b/pkg/remote/azurerm_containerregistry_scanner_test.go deleted file mode 100644 index 718b1282..00000000 --- a/pkg/remote/azurerm_containerregistry_scanner_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/common" - error2 "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermContainerRegistry(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockContainerRegistryRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no container registry", - mocks: func(repository *repository.MockContainerRegistryRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllContainerRegistries").Return([]*armcontainerregistry.Registry{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing container registry", - mocks: func(repository *repository.MockContainerRegistryRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllContainerRegistries").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureContainerRegistryResourceType), - }, - { - test: "multiple container registries", - mocks: func(repository *repository.MockContainerRegistryRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllContainerRegistries").Return([]*armcontainerregistry.Registry{ - { - Resource: armcontainerregistry.Resource{ - ID: to.StringPtr("registry1"), - }, - }, - { - Resource: armcontainerregistry.Resource{ - ID: to.StringPtr("registry2"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "registry1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureContainerRegistryResourceType) - - assert.Equal(t, got[1].ResourceId(), "registry2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureContainerRegistryResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockContainerRegistryRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ContainerRegistryRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermContainerRegistryEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm_network_scanner_test.go b/pkg/remote/azurerm_network_scanner_test.go deleted file mode 100644 index fddf7267..00000000 --- a/pkg/remote/azurerm_network_scanner_test.go +++ /dev/null @@ -1,1012 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - error2 "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermVirtualNetwork(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no virtual network", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return([]*armnetwork.VirtualNetwork{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing virtual network", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureVirtualNetworkResourceType), - }, - { - test: "multiple virtual network", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return([]*armnetwork.VirtualNetwork{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network1"), - Name: to.StringPtr("network1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network2"), - Name: to.StringPtr("network2"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "network1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureVirtualNetworkResourceType) - - assert.Equal(t, got[1].ResourceId(), "network2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureVirtualNetworkResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermVirtualNetworkEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermRouteTables(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no route tables", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing route tables", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureRouteTableResourceType), - }, - { - test: "multiple route tables", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("route1"), - Name: to.StringPtr("route1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("route2"), - Name: to.StringPtr("route2"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "route1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureRouteTableResourceType) - - assert.Equal(t, got[1].ResourceId(), "route2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureRouteTableResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermRouteTableEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermRoutes(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no route tables", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "no routes", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{ - { - Properties: &armnetwork.RouteTablePropertiesFormat{ - Routes: []*armnetwork.Route{}, - }, - }, - { - Properties: &armnetwork.RouteTablePropertiesFormat{ - Routes: []*armnetwork.Route{}, - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing route tables", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureRouteResourceType, resourceazure.AzureRouteTableResourceType), - }, - { - test: "multiple routes", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllRouteTables").Return([]*armnetwork.RouteTable{ - { - Resource: armnetwork.Resource{ - Name: to.StringPtr("table1"), - }, - Properties: &armnetwork.RouteTablePropertiesFormat{ - Routes: []*armnetwork.Route{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("route1"), - }, - Name: to.StringPtr("route1"), - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("route2"), - }, - Name: to.StringPtr("route2"), - }, - }, - }, - }, - { - Resource: armnetwork.Resource{ - Name: to.StringPtr("table2"), - }, - Properties: &armnetwork.RouteTablePropertiesFormat{ - Routes: []*armnetwork.Route{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("route3"), - }, - Name: to.StringPtr("route3"), - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("route4"), - }, - Name: to.StringPtr("route4"), - }, - }, - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 4) - - assert.Equal(t, "route1", got[0].ResourceId()) - assert.Equal(t, resourceazure.AzureRouteResourceType, got[0].ResourceType()) - - assert.Equal(t, "route2", got[1].ResourceId()) - assert.Equal(t, resourceazure.AzureRouteResourceType, got[1].ResourceType()) - - assert.Equal(t, "route3", got[2].ResourceId()) - assert.Equal(t, resourceazure.AzureRouteResourceType, got[2].ResourceType()) - - assert.Equal(t, "route4", got[3].ResourceId()) - assert.Equal(t, resourceazure.AzureRouteResourceType, got[3].ResourceType()) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermRouteEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermSubnets(t *testing.T) { - - dummyError := errors.New("this is an error") - - networks := []*armnetwork.VirtualNetwork{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("network2"), - }, - }, - } - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no subnets", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return(networks, nil) - repository.On("ListAllSubnets", networks[0]).Return([]*armnetwork.Subnet{}, nil).Times(1) - repository.On("ListAllSubnets", networks[1]).Return([]*armnetwork.Subnet{}, nil).Times(1) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing virtual network", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureSubnetResourceType, resourceazure.AzureVirtualNetworkResourceType), - }, - { - test: "error listing subnets", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return(networks, nil) - repository.On("ListAllSubnets", networks[0]).Return(nil, dummyError).Times(1) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureSubnetResourceType), - }, - { - test: "multiple subnets", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllVirtualNetworks").Return(networks, nil) - repository.On("ListAllSubnets", networks[0]).Return([]*armnetwork.Subnet{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet1"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet2"), - }, - }, - }, nil).Times(1) - repository.On("ListAllSubnets", networks[1]).Return([]*armnetwork.Subnet{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet3"), - }, - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("subnet4"), - }, - }, - }, nil).Times(1) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 4) - - assert.Equal(t, got[0].ResourceId(), "subnet1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureSubnetResourceType) - - assert.Equal(t, got[1].ResourceId(), "subnet2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureSubnetResourceType) - - assert.Equal(t, got[2].ResourceId(), "subnet3") - assert.Equal(t, got[2].ResourceType(), resourceazure.AzureSubnetResourceType) - - assert.Equal(t, got[3].ResourceId(), "subnet4") - assert.Equal(t, got[3].ResourceType(), resourceazure.AzureSubnetResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermSubnetEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermFirewalls(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no firewall", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllFirewalls").Return([]*armnetwork.AzureFirewall{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing firewalls", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllFirewalls").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureFirewallResourceType), - }, - { - test: "multiple firewalls", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllFirewalls").Return([]*armnetwork.AzureFirewall{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall1"), // Here we don't care to have a valid ID, it is for testing purpose only - Name: to.StringPtr("firewall1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("firewall2"), - Name: to.StringPtr("firewall2"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "firewall1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureFirewallResourceType) - - assert.Equal(t, got[1].ResourceId(), "firewall2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureFirewallResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermFirewallsEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPublicIP(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no public IP", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPublicIPAddresses").Return([]*armnetwork.PublicIPAddress{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing public IPs", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPublicIPAddresses").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzurePublicIPResourceType), - }, - { - test: "multiple public IP", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPublicIPAddresses").Return([]*armnetwork.PublicIPAddress{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip1"), // Here we don't care to have a valid ID, it is for testing purpose only - Name: to.StringPtr("ip1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("ip2"), - Name: to.StringPtr("ip2"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "ip1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzurePublicIPResourceType) - - assert.Equal(t, got[1].ResourceId(), "ip2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzurePublicIPResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPublicIPEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermSecurityGroups(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no security group", - dirName: "azurerm_network_security_group_empty", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSecurityGroups").Return([]*armnetwork.NetworkSecurityGroup{}, nil) - }, - }, - { - test: "error listing security groups", - dirName: "azurerm_network_security_group_empty", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSecurityGroups").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureNetworkSecurityGroupResourceType), - }, - { - test: "multiple security groups", - dirName: "azurerm_network_security_group_multiple", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllSecurityGroups").Return([]*armnetwork.NetworkSecurityGroup{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup1"), - Name: to.StringPtr("acceptanceTestSecurityGroup1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup2"), - Name: to.StringPtr("acceptanceTestSecurityGroup2"), - }, - }, - }, nil) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraform2.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewNetworkRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermNetworkSecurityGroupEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzureNetworkSecurityGroupResourceType, common.NewGenericDetailsFetcher(resourceazure.AzureNetworkSecurityGroupResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzureNetworkSecurityGroupResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermLoadBalancers(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no load balancer", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing load balancers", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureLoadBalancerResourceType), - }, - { - test: "multiple load balancers", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{ - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("lb-1"), // Here we don't care to have a valid ID, it is for testing purpose only - Name: to.StringPtr("lb-1"), - }, - }, - { - Resource: armnetwork.Resource{ - ID: to.StringPtr("lb-2"), - Name: to.StringPtr("lb-2"), - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "lb-1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureLoadBalancerResourceType) - - assert.Equal(t, got[1].ResourceId(), "lb-2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureLoadBalancerResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermLoadBalancerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermLoadBalancerRules(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockNetworkRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no load balancer rule", - dirName: "azurerm_lb_rule_empty", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - loadbalancer := &armnetwork.LoadBalancer{ - Resource: armnetwork.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress"), - Name: to.StringPtr("testlb"), - }, - } - - repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{loadbalancer}, nil) - - repository.On("ListLoadBalancerRules", loadbalancer).Return([]*armnetwork.LoadBalancingRule{}, nil) - }, - }, - { - test: "error listing load balancer rules", - dirName: "azurerm_lb_rule_empty", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllLoadBalancers").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureLoadBalancerRuleResourceType, resourceazure.AzureLoadBalancerResourceType), - }, - { - test: "multiple load balancer rules", - dirName: "azurerm_lb_rule_multiple", - mocks: func(repository *repository.MockNetworkRepository, alerter *mocks.AlerterInterface) { - loadbalancer := &armnetwork.LoadBalancer{ - Resource: armnetwork.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress"), - Name: to.StringPtr("TestLoadBalancer"), - }, - } - - repository.On("ListAllLoadBalancers").Return([]*armnetwork.LoadBalancer{loadbalancer}, nil) - - repository.On("ListLoadBalancerRules", loadbalancer).Return([]*armnetwork.LoadBalancingRule{ - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule"), - }, - Name: to.StringPtr("LBRule"), - }, - { - SubResource: armnetwork.SubResource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule2"), - }, - Name: to.StringPtr("LBRule2"), - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockNetworkRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.NetworkRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraform2.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewNetworkRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermLoadBalancerRuleEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzureLoadBalancerRuleResourceType, common.NewGenericDetailsFetcher(resourceazure.AzureLoadBalancerRuleResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzureLoadBalancerRuleResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm_postgresql_scanner_test.go b/pkg/remote/azurerm_postgresql_scanner_test.go deleted file mode 100644 index 0ba882b8..00000000 --- a/pkg/remote/azurerm_postgresql_scanner_test.go +++ /dev/null @@ -1,244 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermPostgresqlServer(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockPostgresqlRespository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no postgres server", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return([]*armpostgresql.Server{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing postgres servers", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePostgresqlServerResourceType), - }, - { - test: "multiple postgres servers", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return([]*armpostgresql.Server{ - { - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("server1"), - Name: to.StringPtr("server1"), - }, - }, - }, - { - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("server2"), - Name: to.StringPtr("server2"), - }, - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "server1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzurePostgresqlServerResourceType) - - assert.Equal(t, got[1].ResourceId(), "server2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzurePostgresqlServerResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPostgresqlRespository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PostgresqlRespository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPostgresqlServerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPostgresqlDatabase(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockPostgresqlRespository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no postgres database", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return([]*armpostgresql.Server{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing postgres servers", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePostgresqlDatabaseResourceType, resourceazure.AzurePostgresqlServerResourceType), - }, - { - test: "error listing postgres databases", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return([]*armpostgresql.Server{ - { - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/api-rg-pro/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-8791542"), - Name: to.StringPtr("postgresql-server-8791542"), - }, - }, - }, - }, nil).Once() - - repository.On("ListAllDatabasesByServer", mock.IsType(&armpostgresql.Server{})).Return(nil, dummyError).Once() - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePostgresqlDatabaseResourceType), - }, - { - test: "multiple postgres databases", - mocks: func(repository *repository.MockPostgresqlRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllServers").Return([]*armpostgresql.Server{ - { - TrackedResource: armpostgresql.TrackedResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/api-rg-pro/providers/Microsoft.DBforPostgreSQL/servers/postgresql-server-8791542"), - Name: to.StringPtr("postgresql-server-8791542"), - }, - }, - }, - }, nil).Once() - - repository.On("ListAllDatabasesByServer", mock.IsType(&armpostgresql.Server{})).Return([]*armpostgresql.Database{ - { - ProxyResource: armpostgresql.ProxyResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("db1"), - Name: to.StringPtr("db1"), - }, - }, - }, - { - ProxyResource: armpostgresql.ProxyResource{ - Resource: armpostgresql.Resource{ - ID: to.StringPtr("db2"), - Name: to.StringPtr("db2"), - }, - }, - }, - }, nil).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "db1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzurePostgresqlDatabaseResourceType) - - assert.Equal(t, got[1].ResourceId(), "db2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzurePostgresqlDatabaseResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPostgresqlRespository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PostgresqlRespository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPostgresqlDatabaseEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm_privatedns_scanner_test.go b/pkg/remote/azurerm_privatedns_scanner_test.go deleted file mode 100644 index a76bc30f..00000000 --- a/pkg/remote/azurerm_privatedns_scanner_test.go +++ /dev/null @@ -1,1218 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/azidentity" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraformtest "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermPrivateDNSZone(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private zone", - dirName: "azurerm_private_dns_private_zone_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zones", - dirName: "azurerm_private_dns_private_zone_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "multiple private zones", - dirName: "azurerm_private_dns_private_zone_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf2.com"), - Name: to.StringPtr("thisisatestusingtf2.com"), - }, - }, - }, - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/testmartin.com"), - Name: to.StringPtr("testmartin.com"), - }, - }, - }, - }, nil) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSZoneEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSZoneResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSZoneResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSZoneResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSARecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private a record", - dirName: "azurerm_private_dns_a_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_a_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSARecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private a records", - dirName: "azurerm_private_dns_a_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllARecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSARecordResourceType), - }, - { - test: "multiple private a records", - dirName: "azurerm_private_dns_a_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllARecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/test"), - Name: to.StringPtr("test"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("10.0.180.17")}, - {IPv4Address: to.StringPtr("10.0.180.20")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/othertest"), - Name: to.StringPtr("othertest"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - ARecords: []*armprivatedns.ARecord{ - {IPv4Address: to.StringPtr("10.0.180.20")}, - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSARecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSARecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSARecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSARecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSAAAARecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private aaaa record", - dirName: "azurerm_private_dns_aaaa_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_aaaa_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSAAAARecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private aaaa records", - dirName: "azurerm_private_dns_aaaa_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllAAAARecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSAAAARecordResourceType), - }, - { - test: "multiple private aaaaa records", - dirName: "azurerm_private_dns_aaaaa_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllAAAARecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/test"), - Name: to.StringPtr("test"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008:0000:0000:0000:7334")}, - {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008::7335")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/othertest"), - Name: to.StringPtr("othertest"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - AaaaRecords: []*armprivatedns.AaaaRecord{ - {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008:0000:0000:0000:7334")}, - {IPv6Address: to.StringPtr("fd5d:70bc:930e:d008::7335")}, - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSAAAARecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSAAAARecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSAAAARecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSAAAARecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSCNAMERecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private cname record", - dirName: "azurerm_private_dns_cname_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_cname_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSCNameRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private cname records", - dirName: "azurerm_private_dns_cname_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllCNAMERecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSCNameRecordResourceType), - }, - { - test: "multiple private cname records", - dirName: "azurerm_private_dns_cname_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllCNAMERecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/test"), - Name: to.StringPtr("test"), - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/othertest"), - Name: to.StringPtr("othertest"), - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSCNameRecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSCNameRecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSCNameRecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSCNameRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSPTRRecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private ptr record", - dirName: "azurerm_private_dns_ptr_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_ptr_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSPTRRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private ptr records", - dirName: "azurerm_private_dns_ptr_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllPTRRecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSPTRRecordResourceType), - }, - { - test: "multiple private ptra records", - dirName: "azurerm_private_dns_ptr_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllPTRRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/othertestptr"), - Name: to.StringPtr("othertestptr"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptr1.thisisatestusingtf.com")}, - {Ptrdname: to.StringPtr("ptr2.thisisatestusingtf.com")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/testptr"), - Name: to.StringPtr("testptr"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("ptr3.thisisatestusingtf.com")}, - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSPTRRecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSPTRRecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSPTRRecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSPTRRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSMXRecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private mx record", - dirName: "azurerm_private_dns_mx_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_mx_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSMXRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private mx records", - dirName: "azurerm_private_dns_mx_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllMXRecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSMXRecordResourceType), - }, - { - test: "multiple private mx records", - dirName: "azurerm_private_dns_mx_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllMXRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/othertestmx"), - Name: to.StringPtr("othertestmx"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex1")}, - {Exchange: to.StringPtr("ex2")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/testmx"), - Name: to.StringPtr("testmx"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - MxRecords: []*armprivatedns.MxRecord{ - {Exchange: to.StringPtr("ex1")}, - {Exchange: to.StringPtr("ex2")}, - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSMXRecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSMXRecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSMXRecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSMXRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSSRVRecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private srv record", - dirName: "azurerm_private_dns_srv_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_srv_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSSRVRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private srv records", - dirName: "azurerm_private_dns_srv_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllSRVRecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSSRVRecordResourceType), - }, - { - test: "multiple private srv records", - dirName: "azurerm_private_dns_srv_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllSRVRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/othertestptr"), - Name: to.StringPtr("othertestptr"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - SrvRecords: []*armprivatedns.SrvRecord{ - {Target: to.StringPtr("srv1.thisisatestusingtf.com")}, - {Target: to.StringPtr("srv2.thisisatestusingtf.com")}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/testptr"), - Name: to.StringPtr("testptr"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("srv3.thisisatestusingtf.com")}, - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSSRVRecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSSRVRecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSSRVRecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSSRVRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermPrivateDNSTXTRecord(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - dirName string - mocks func(*repository.MockPrivateDNSRepository, *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no private txt record", - dirName: "azurerm_private_dns_txt_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{}, nil) - }, - }, - { - test: "error listing private zone", - dirName: "azurerm_private_dns_txt_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingErrorWithType(dummyError, resourceazure.AzurePrivateDNSTXTRecordResourceType, resourceazure.AzurePrivateDNSZoneResourceType), - }, - { - test: "error listing private txt records", - dirName: "azurerm_private_dns_txt_record_empty", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - repository.On("ListAllTXTRecords", mock.Anything).Return(nil, dummyError) - }, - wantErr: remoteerr.NewResourceListingError(dummyError, resourceazure.AzurePrivateDNSTXTRecordResourceType), - }, - { - test: "multiple private txt records", - dirName: "azurerm_private_dns_txt_record_multiple", - mocks: func(repository *repository.MockPrivateDNSRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllPrivateZones").Return([]*armprivatedns.PrivateZone{ - { - TrackedResource: armprivatedns.TrackedResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com"), - Name: to.StringPtr("thisisatestusingtf.com"), - }, - }, - }, - }, nil) - - repository.On("ListAllTXTRecords", mock.Anything).Return([]*armprivatedns.RecordSet{ - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/othertesttxt"), - Name: to.StringPtr("othertesttxt"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - TxtRecords: []*armprivatedns.TxtRecord{ - {Value: []*string{to.StringPtr("this is value line 1")}}, - {Value: []*string{to.StringPtr("this is value line 2")}}, - }, - }, - }, - { - ProxyResource: armprivatedns.ProxyResource{ - Resource: armprivatedns.Resource{ - ID: to.StringPtr("/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/testtxt"), - Name: to.StringPtr("testtxt"), - }, - }, - Properties: &armprivatedns.RecordSetProperties{ - PtrRecords: []*armprivatedns.PtrRecord{ - {Ptrdname: to.StringPtr("this is value line 3")}, - }, - }, - }, - }, nil).Once() - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockPrivateDNSRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.PrivateDNSRepository = fakeRepo - providerVersion := "2.71.0" - realProvider, err := terraformtest.InitTestAzureProvider(providerLibrary, providerVersion) - if err != nil { - t.Fatal(err) - } - provider := terraformtest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - cred, err := azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{}) - if err != nil { - t.Fatal(err) - } - clientOptions := &arm.ClientOptions{} - repo = repository.NewPrivateDNSRepository(cred, clientOptions, realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(azurerm.NewAzurermPrivateDNSTXTRecordEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resourceazure.AzurePrivateDNSTXTRecordResourceType, common.NewGenericDetailsFetcher(resourceazure.AzurePrivateDNSTXTRecordResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - - if err != nil { - return - } - test.TestAgainstGoldenFile(got, resourceazure.AzurePrivateDNSTXTRecordResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm_resources_scanner_test.go b/pkg/remote/azurerm_resources_scanner_test.go deleted file mode 100644 index 44fca908..00000000 --- a/pkg/remote/azurerm_resources_scanner_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/common" - error2 "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermResourceGroups(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockResourcesRepository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no resource group", - mocks: func(repository *repository.MockResourcesRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllResourceGroups").Return([]*armresources.ResourceGroup{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing resource groups", - mocks: func(repository *repository.MockResourcesRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllResourceGroups").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureResourceGroupResourceType), - }, - { - test: "multiple resource groups", - mocks: func(repository *repository.MockResourcesRepository, alerter *mocks.AlerterInterface) { - repository.On("ListAllResourceGroups").Return([]*armresources.ResourceGroup{ - { - ID: to.StringPtr("group1"), - Name: to.StringPtr("group1"), - }, - { - ID: to.StringPtr("group2"), - Name: to.StringPtr("group2"), - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "group1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureResourceGroupResourceType) - - assert.Equal(t, got[1].ResourceId(), "group2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureResourceGroupResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockResourcesRepository{} - c.mocks(fakeRepo, alerter) - - var repo repository.ResourcesRepository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermResourceGroupEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/azurerm_storage_scanner_test.go b/pkg/remote/azurerm_storage_scanner_test.go deleted file mode 100644 index 54804640..00000000 --- a/pkg/remote/azurerm_storage_scanner_test.go +++ /dev/null @@ -1,260 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/azurerm/repository" - "github.com/snyk/driftctl/pkg/remote/common" - error2 "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - resourceazure "github.com/snyk/driftctl/pkg/resource/azurerm" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestAzurermStorageAccount(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockStorageRespository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no storage accounts", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing storage accounts", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStorageAccount").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureStorageAccountResourceType), - }, - { - test: "multiple storage accounts", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{ - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl1"), - }, - }, - }, - { - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl2"), - }, - }, - }, - }, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "testeliedriftctl1") - assert.Equal(t, got[0].ResourceType(), resourceazure.AzureStorageAccountResourceType) - - assert.Equal(t, got[1].ResourceId(), "testeliedriftctl2") - assert.Equal(t, got[1].ResourceType(), resourceazure.AzureStorageAccountResourceType) - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockStorageRespository{} - c.mocks(fakeRepo, alerter) - - var repo repository.StorageRespository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermStorageAccountEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} - -func TestAzurermStorageContainer(t *testing.T) { - - dummyError := errors.New("this is an error") - - tests := []struct { - test string - mocks func(*repository.MockStorageRespository, *mocks.AlerterInterface) - assertExpected func(t *testing.T, got []*resource.Resource) - wantErr error - }{ - { - test: "no storage accounts", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "no storage containers", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - account1 := &armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl1"), - }, - }, - } - account2 := &armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl1"), - }, - }, - } - repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{ - account1, - account2, - }, nil) - repository.On("ListAllStorageContainer", account1).Return([]string{}, nil) - repository.On("ListAllStorageContainer", account2).Return([]string{}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "error listing storage accounts", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - repository.On("ListAllStorageAccount").Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingErrorWithType(dummyError, resourceazure.AzureStorageContainerResourceType, resourceazure.AzureStorageAccountResourceType), - }, - { - test: "error listing storage container", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - account := &armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl1"), - }, - }, - } - repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{account}, nil) - repository.On("ListAllStorageContainer", account).Return(nil, dummyError) - }, - wantErr: error2.NewResourceListingError(dummyError, resourceazure.AzureStorageContainerResourceType), - }, - { - test: "multiple storage containers", - mocks: func(repository *repository.MockStorageRespository, alerter *mocks.AlerterInterface) { - account1 := &armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl1"), - }, - }, - } - account2 := &armstorage.StorageAccount{ - TrackedResource: armstorage.TrackedResource{ - Resource: armstorage.Resource{ - ID: func(s string) *string { return &s }("testeliedriftctl2"), - }, - }, - } - repository.On("ListAllStorageAccount").Return([]*armstorage.StorageAccount{ - account1, - account2, - }, nil) - repository.On("ListAllStorageContainer", account1).Return([]string{"https://testeliedriftctl1.blob.core.windows.net/container1", "https://testeliedriftctl1.blob.core.windows.net/container2"}, nil) - repository.On("ListAllStorageContainer", account2).Return([]string{"https://testeliedriftctl2.blob.core.windows.net/container3", "https://testeliedriftctl2.blob.core.windows.net/container4"}, nil) - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 4) - - for _, container := range got { - assert.Equal(t, container.ResourceType(), resourceazure.AzureStorageContainerResourceType) - } - - assert.Equal(t, got[0].ResourceId(), "https://testeliedriftctl1.blob.core.windows.net/container1") - assert.Equal(t, got[1].ResourceId(), "https://testeliedriftctl1.blob.core.windows.net/container2") - assert.Equal(t, got[2].ResourceId(), "https://testeliedriftctl2.blob.core.windows.net/container3") - assert.Equal(t, got[3].ResourceId(), "https://testeliedriftctl2.blob.core.windows.net/container4") - }, - }, - } - - providerVersion := "2.71.0" - schemaRepository := testresource.InitFakeSchemaRepository("azurerm", providerVersion) - resourceazure.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - - scanOptions := ScannerOptions{} - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - fakeRepo := &repository.MockStorageRespository{} - c.mocks(fakeRepo, alerter) - - var repo repository.StorageRespository = fakeRepo - - remoteLibrary.AddEnumerator(azurerm.NewAzurermStorageContainerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - c.assertExpected(tt, got) - alerter.AssertExpectations(tt) - fakeRepo.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/cache/cache_test.go b/pkg/remote/cache/cache_test.go deleted file mode 100644 index 9f2dc630..00000000 --- a/pkg/remote/cache/cache_test.go +++ /dev/null @@ -1,157 +0,0 @@ -package cache - -import ( - "fmt" - "sync" - "testing" - "time" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/stretchr/testify/assert" -) - -func BenchmarkCache(b *testing.B) { - cache := New(500) - for i := 0; i < b.N; i++ { - key := fmt.Sprintf("test-key-%d", i) - data := make([]*resource.Resource, 1024) - assert.Equal(b, false, cache.Put(key, data)) - assert.Equal(b, data, cache.Get(key)) - } -} - -func TestCache(t *testing.T) { - t.Run("should return nil on non-existing key", func(t *testing.T) { - cache := New(5) - assert.Equal(t, nil, cache.Get("test")) - assert.Equal(t, 0, cache.Len()) - }) - - t.Run("should retrieve newly added key", func(t *testing.T) { - cache := New(5) - assert.Equal(t, false, cache.Put("s3", []string{})) - assert.Equal(t, []string{}, cache.Get("s3")) - assert.Equal(t, 1, cache.Len()) - }) - - t.Run("should override existing key", func(t *testing.T) { - cache := New(5) - assert.Equal(t, false, cache.Put("s3", []string{})) - assert.Equal(t, []string{}, cache.Get("s3")) - - assert.Equal(t, true, cache.Put("s3", []string{"test"})) - assert.Equal(t, []string{"test"}, cache.Get("s3")) - assert.Equal(t, 1, cache.Len()) - }) - - t.Run("should delete the least used keys", func(t *testing.T) { - keys := []struct { - key string - value interface{} - }{ - {key: "test-0", value: nil}, - {key: "test-1", value: nil}, - {key: "test-2", value: nil}, - {key: "test-3", value: nil}, - {key: "test-4", value: nil}, - {key: "test-5", value: nil}, - {key: "test-6", value: "value"}, - {key: "test-7", value: "value"}, - {key: "test-8", value: "value"}, - {key: "test-9", value: "value"}, - {key: "test-10", value: "value"}, - } - - cache := New(5) - for i := 0; i <= 10; i++ { - cache.Put(fmt.Sprintf("test-%d", i), "value") - } - for _, k := range keys { - assert.Equal(t, k.value, cache.Get(k.key)) - } - assert.Equal(t, 5, cache.Len()) - }) - - t.Run("should ignore keys when capacity is 0", func(t *testing.T) { - keys := []struct { - key string - value interface{} - }{ - { - "test", - []string{"slice"}, - }, - { - "test", - []string{}, - }, - { - "test2", - []*resource.Resource{}, - }, - } - cache := New(0) - - for _, k := range keys { - assert.Equal(t, false, cache.Put(k.key, k.value)) - assert.Equal(t, nil, cache.Get(k.key)) - } - assert.Equal(t, 0, cache.Len()) - }) - - t.Run("cache will not panic for parallel calls", func(t *testing.T) { - key := "sameKeyForMultiplesRoutines" - - cache := New(1) - - wg := sync.WaitGroup{} - missCount := 0 - - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() - hit := cache.Get(key) - if hit != nil { - return - } - missCount++ - time.Sleep(10 * time.Millisecond) - cache.Put(key, "value") - }() - } - - wg.Wait() - assert.Equal(t, cache.Get(key), "value") - assert.Greater(t, missCount, 1) - }) - - t.Run("cache should be missed only once with parallel calls and GetAndLock usage", func(t *testing.T) { - key := "sameKeyForMultiplesRoutines" - - cache := New(1) - - nbRoutines := 100 - wg := sync.WaitGroup{} - wg.Add(nbRoutines) - - missCount := 0 - for i := 0; i < nbRoutines; i++ { - go func() { - defer wg.Done() - hit := cache.GetAndLock(key) - defer cache.Unlock(key) - if hit != nil { - return - } - missCount++ - time.Sleep(1 * time.Millisecond) - cache.Put(key, "value") - }() - } - - wg.Wait() - assert.Equal(t, cache.Get(key), "value") - assert.Equal(t, 1, missCount) - }) -} diff --git a/pkg/remote/common/details_fetcher.go b/pkg/remote/common/details_fetcher.go deleted file mode 100644 index 6b72cd65..00000000 --- a/pkg/remote/common/details_fetcher.go +++ /dev/null @@ -1,54 +0,0 @@ -package common - -import ( - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" -) - -type DetailsFetcher interface { - ReadDetails(*resource.Resource) (*resource.Resource, error) -} - -type GenericDetailsFetcher struct { - resType resource.ResourceType - reader terraform.ResourceReader - deserializer *resource.Deserializer -} - -func NewGenericDetailsFetcher(resType resource.ResourceType, provider terraform.ResourceReader, deserializer *resource.Deserializer) *GenericDetailsFetcher { - return &GenericDetailsFetcher{ - resType: resType, - reader: provider, - deserializer: deserializer, - } -} - -func (f *GenericDetailsFetcher) ReadDetails(res *resource.Resource) (*resource.Resource, error) { - attributes := map[string]string{} - if res.Schema().ResolveReadAttributesFunc != nil { - attributes = res.Schema().ResolveReadAttributesFunc(res) - } - ctyVal, err := f.reader.ReadResource(terraform.ReadResourceArgs{ - Ty: f.resType, - ID: res.ResourceId(), - Attributes: attributes, - }) - if err != nil { - return nil, remoteerror.NewResourceScanningError(err, res.ResourceType(), res.ResourceId()) - } - if ctyVal.IsNull() { - logrus.WithFields(logrus.Fields{ - "type": f.resType, - "id": res.ResourceId(), - }).Debug("Got null while reading resource details") - return nil, nil - } - deserializedRes, err := f.deserializer.DeserializeOne(string(f.resType), *ctyVal) - if err != nil { - return nil, err - } - - return deserializedRes, nil -} diff --git a/pkg/remote/common/library.go b/pkg/remote/common/library.go deleted file mode 100644 index f28e4ef2..00000000 --- a/pkg/remote/common/library.go +++ /dev/null @@ -1,38 +0,0 @@ -package common - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -type Enumerator interface { - SupportedType() resource.ResourceType - Enumerate() ([]*resource.Resource, error) -} - -type RemoteLibrary struct { - enumerators []Enumerator - detailsFetchers map[resource.ResourceType]DetailsFetcher -} - -func NewRemoteLibrary() *RemoteLibrary { - return &RemoteLibrary{ - make([]Enumerator, 0), - make(map[resource.ResourceType]DetailsFetcher), - } -} - -func (r *RemoteLibrary) AddEnumerator(enumerator Enumerator) { - r.enumerators = append(r.enumerators, enumerator) -} - -func (r *RemoteLibrary) Enumerators() []Enumerator { - return r.enumerators -} - -func (r *RemoteLibrary) AddDetailsFetcher(ty resource.ResourceType, detailsFetcher DetailsFetcher) { - r.detailsFetchers[ty] = detailsFetcher -} - -func (r *RemoteLibrary) GetDetailsFetcher(ty resource.ResourceType) DetailsFetcher { - return r.detailsFetchers[ty] -} diff --git a/pkg/remote/common/mock_Enumerator.go b/pkg/remote/common/mock_Enumerator.go deleted file mode 100644 index 339d5f6a..00000000 --- a/pkg/remote/common/mock_Enumerator.go +++ /dev/null @@ -1,50 +0,0 @@ -// Code generated by mockery v0.0.0-dev. DO NOT EDIT. - -package common - -import ( - resource "github.com/snyk/driftctl/pkg/resource" - mock "github.com/stretchr/testify/mock" -) - -// MockEnumerator is an autogenerated mock type for the Enumerator type -type MockEnumerator struct { - mock.Mock -} - -// Enumerate provides a mock function with given fields: -func (_m *MockEnumerator) Enumerate() ([]*resource.Resource, error) { - ret := _m.Called() - - var r0 []*resource.Resource - if rf, ok := ret.Get(0).(func() []*resource.Resource); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*resource.Resource) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SupportedType provides a mock function with given fields: -func (_m *MockEnumerator) SupportedType() resource.ResourceType { - ret := _m.Called() - - var r0 resource.ResourceType - if rf, ok := ret.Get(0).(func() resource.ResourceType); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(resource.ResourceType) - } - - return r0 -} diff --git a/pkg/remote/common/providers.go b/pkg/remote/common/providers.go deleted file mode 100644 index c8d46904..00000000 --- a/pkg/remote/common/providers.go +++ /dev/null @@ -1,30 +0,0 @@ -package common - -import ( - tf "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/pkg/terraform/lock" -) - -type RemoteParameter string - -const ( - RemoteAWSTerraform = "aws+tf" - RemoteGithubTerraform = "github+tf" - RemoteGoogleTerraform = "gcp+tf" - RemoteAzureTerraform = "azure+tf" -) - -var remoteParameterMapping = map[RemoteParameter]string{ - RemoteAWSTerraform: tf.AWS, - RemoteGithubTerraform: tf.GITHUB, - RemoteGoogleTerraform: tf.GOOGLE, - RemoteAzureTerraform: tf.AZURE, -} - -func (p RemoteParameter) GetProviderAddress() *lock.ProviderAddress { - return &lock.ProviderAddress{ - Hostname: "registry.terraform.io", - Namespace: "hashicorp", - Type: remoteParameterMapping[p], - } -} diff --git a/pkg/remote/github/github_branch_protection_enumerator.go b/pkg/remote/github/github_branch_protection_enumerator.go deleted file mode 100644 index ab05579b..00000000 --- a/pkg/remote/github/github_branch_protection_enumerator.go +++ /dev/null @@ -1,45 +0,0 @@ -package github - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/github" -) - -type GithubBranchProtectionEnumerator struct { - repository GithubRepository - factory resource.ResourceFactory -} - -func NewGithubBranchProtectionEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubBranchProtectionEnumerator { - return &GithubBranchProtectionEnumerator{ - repository: repo, - factory: factory, - } -} - -func (g *GithubBranchProtectionEnumerator) SupportedType() resource.ResourceType { - return github.GithubBranchProtectionResourceType -} - -func (g *GithubBranchProtectionEnumerator) Enumerate() ([]*resource.Resource, error) { - ids, err := g.repository.ListBranchProtection() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(ids)) - - for _, id := range ids { - results = append( - results, - g.factory.CreateAbstractResource( - string(g.SupportedType()), - id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/github/github_membership_enumerator.go b/pkg/remote/github/github_membership_enumerator.go deleted file mode 100644 index 9c5db5b7..00000000 --- a/pkg/remote/github/github_membership_enumerator.go +++ /dev/null @@ -1,45 +0,0 @@ -package github - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/github" -) - -type GithubMembershipEnumerator struct { - Membership GithubRepository - factory resource.ResourceFactory -} - -func NewGithubMembershipEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubMembershipEnumerator { - return &GithubMembershipEnumerator{ - Membership: repo, - factory: factory, - } -} - -func (g *GithubMembershipEnumerator) SupportedType() resource.ResourceType { - return github.GithubMembershipResourceType -} - -func (g *GithubMembershipEnumerator) Enumerate() ([]*resource.Resource, error) { - ids, err := g.Membership.ListMembership() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(ids)) - - for _, id := range ids { - results = append( - results, - g.factory.CreateAbstractResource( - string(g.SupportedType()), - id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/github/github_repository_enumerator.go b/pkg/remote/github/github_repository_enumerator.go deleted file mode 100644 index 6d624486..00000000 --- a/pkg/remote/github/github_repository_enumerator.go +++ /dev/null @@ -1,45 +0,0 @@ -package github - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/github" -) - -type GithubRepositoryEnumerator struct { - repository GithubRepository - factory resource.ResourceFactory -} - -func NewGithubRepositoryEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubRepositoryEnumerator { - return &GithubRepositoryEnumerator{ - repository: repo, - factory: factory, - } -} - -func (g *GithubRepositoryEnumerator) SupportedType() resource.ResourceType { - return github.GithubRepositoryResourceType -} - -func (g *GithubRepositoryEnumerator) Enumerate() ([]*resource.Resource, error) { - ids, err := g.repository.ListRepositories() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(ids)) - - for _, id := range ids { - results = append( - results, - g.factory.CreateAbstractResource( - string(g.SupportedType()), - id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/github/github_team_enumerator.go b/pkg/remote/github/github_team_enumerator.go deleted file mode 100644 index 98b195b4..00000000 --- a/pkg/remote/github/github_team_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package github - -import ( - "fmt" - - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/github" -) - -type GithubTeamEnumerator struct { - repository GithubRepository - factory resource.ResourceFactory -} - -func NewGithubTeamEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubTeamEnumerator { - return &GithubTeamEnumerator{ - repository: repo, - factory: factory, - } -} - -func (g *GithubTeamEnumerator) SupportedType() resource.ResourceType { - return github.GithubTeamResourceType -} - -func (g *GithubTeamEnumerator) Enumerate() ([]*resource.Resource, error) { - resourceList, err := g.repository.ListTeams() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resourceList)) - - for _, team := range resourceList { - results = append( - results, - g.factory.CreateAbstractResource( - string(g.SupportedType()), - fmt.Sprintf("%d", team.DatabaseId), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/github/github_team_membership_enumerator.go b/pkg/remote/github/github_team_membership_enumerator.go deleted file mode 100644 index 7f87be72..00000000 --- a/pkg/remote/github/github_team_membership_enumerator.go +++ /dev/null @@ -1,45 +0,0 @@ -package github - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/github" -) - -type GithubTeamMembershipEnumerator struct { - repository GithubRepository - factory resource.ResourceFactory -} - -func NewGithubTeamMembershipEnumerator(repo GithubRepository, factory resource.ResourceFactory) *GithubTeamMembershipEnumerator { - return &GithubTeamMembershipEnumerator{ - repository: repo, - factory: factory, - } -} - -func (g *GithubTeamMembershipEnumerator) SupportedType() resource.ResourceType { - return github.GithubTeamMembershipResourceType -} - -func (g *GithubTeamMembershipEnumerator) Enumerate() ([]*resource.Resource, error) { - ids, err := g.repository.ListTeamMemberships() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(g.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(ids)) - - for _, id := range ids { - results = append( - results, - g.factory.CreateAbstractResource( - string(g.SupportedType()), - id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/github/init.go b/pkg/remote/github/init.go deleted file mode 100644 index 68924cbd..00000000 --- a/pkg/remote/github/init.go +++ /dev/null @@ -1,63 +0,0 @@ -package github - -import ( - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" -) - -/** - * Initialize remote (configure credentials, launch tf providers and start gRPC clients) - * Required to use Scanner - */ - -func Init(version string, alerter *alerter.Alerter, - providerLibrary *terraform.ProviderLibrary, - remoteLibrary *common.RemoteLibrary, - progress output.Progress, - resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory, - configDir string) error { - - provider, err := NewGithubTerraformProvider(version, progress, configDir) - if err != nil { - return err - } - err = provider.Init() - if err != nil { - return err - } - - repositoryCache := cache.New(100) - - repository := NewGithubRepository(provider.GetConfig(), repositoryCache) - deserializer := resource.NewDeserializer(factory) - providerLibrary.AddProvider(terraform.GITHUB, provider) - - remoteLibrary.AddEnumerator(NewGithubTeamEnumerator(repository, factory)) - remoteLibrary.AddDetailsFetcher(github.GithubTeamResourceType, common.NewGenericDetailsFetcher(github.GithubTeamResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGithubRepositoryEnumerator(repository, factory)) - remoteLibrary.AddDetailsFetcher(github.GithubRepositoryResourceType, common.NewGenericDetailsFetcher(github.GithubRepositoryResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGithubMembershipEnumerator(repository, factory)) - remoteLibrary.AddDetailsFetcher(github.GithubMembershipResourceType, common.NewGenericDetailsFetcher(github.GithubMembershipResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGithubTeamMembershipEnumerator(repository, factory)) - remoteLibrary.AddDetailsFetcher(github.GithubTeamMembershipResourceType, common.NewGenericDetailsFetcher(github.GithubTeamMembershipResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGithubBranchProtectionEnumerator(repository, factory)) - remoteLibrary.AddDetailsFetcher(github.GithubBranchProtectionResourceType, common.NewGenericDetailsFetcher(github.GithubBranchProtectionResourceType, provider, deserializer)) - - err = resourceSchemaRepository.Init(terraform.GITHUB, provider.Version(), provider.Schema()) - if err != nil { - return err - } - github.InitResourcesMetadata(resourceSchemaRepository) - - return nil -} diff --git a/pkg/remote/github/provider.go b/pkg/remote/github/provider.go deleted file mode 100644 index b893e554..00000000 --- a/pkg/remote/github/provider.go +++ /dev/null @@ -1,76 +0,0 @@ -package github - -import ( - "os" - - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/terraform" - tf "github.com/snyk/driftctl/pkg/terraform" -) - -type GithubTerraformProvider struct { - *terraform.TerraformProvider - name string - version string -} - -type githubConfig struct { - Token string - Owner string `cty:"owner"` - Organization string -} - -func NewGithubTerraformProvider(version string, progress output.Progress, configDir string) (*GithubTerraformProvider, error) { - if version == "" { - version = "4.4.0" - } - p := &GithubTerraformProvider{ - version: version, - name: "github", - } - installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ - Key: p.name, - Version: version, - ConfigDir: configDir, - }) - if err != nil { - return nil, err - } - tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ - Name: p.name, - DefaultAlias: p.GetConfig().getDefaultOwner(), - GetProviderConfig: func(owner string) interface{} { - return githubConfig{ - Owner: p.GetConfig().getDefaultOwner(), - } - }, - }, progress) - if err != nil { - return nil, err - } - p.TerraformProvider = tfProvider - return p, err -} - -func (c githubConfig) getDefaultOwner() string { - if c.Organization != "" { - return c.Organization - } - return c.Owner -} - -func (p GithubTerraformProvider) GetConfig() githubConfig { - return githubConfig{ - Token: os.Getenv("GITHUB_TOKEN"), - Owner: os.Getenv("GITHUB_OWNER"), - Organization: os.Getenv("GITHUB_ORGANIZATION"), - } -} - -func (p *GithubTerraformProvider) Name() string { - return p.name -} - -func (p *GithubTerraformProvider) Version() string { - return p.version -} diff --git a/pkg/remote/github/repository.go b/pkg/remote/github/repository.go deleted file mode 100644 index 25bfdb54..00000000 --- a/pkg/remote/github/repository.go +++ /dev/null @@ -1,361 +0,0 @@ -package github - -import ( - "context" - "fmt" - - "github.com/shurcooL/githubv4" - "github.com/snyk/driftctl/pkg/remote/cache" - "golang.org/x/oauth2" -) - -type GithubRepository interface { - ListRepositories() ([]string, error) - ListTeams() ([]Team, error) - ListMembership() ([]string, error) - ListTeamMemberships() ([]string, error) - ListBranchProtection() ([]string, error) -} - -type GithubGraphQLClient interface { - Query(ctx context.Context, q interface{}, variables map[string]interface{}) error -} - -type githubRepository struct { - client GithubGraphQLClient - ctx context.Context - config githubConfig - cache cache.Cache -} - -func NewGithubRepository(config githubConfig, c cache.Cache) *githubRepository { - ctx := context.Background() - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: config.Token}, - ) - oauthClient := oauth2.NewClient(ctx, ts) - - repo := &githubRepository{ - client: githubv4.NewClient(oauthClient), - ctx: context.Background(), - config: config, - cache: c, - } - - return repo -} - -func (r *githubRepository) ListRepositories() ([]string, error) { - if v := r.cache.Get("githubListRepositories"); v != nil { - return v.([]string), nil - } - - if r.config.Organization != "" { - results, err := r.listRepoForOrg() - if err != nil { - return nil, err - } - r.cache.Put("githubListRepositories", results) - return results, nil - } - - results, err := r.listRepoForOwner() - if err != nil { - return nil, err - } - r.cache.Put("githubListRepositories", results) - return results, nil -} - -type pageInfo struct { - EndCursor githubv4.String - HasNextPage bool -} - -type listRepoForOrgQuery struct { - Organization struct { - Repositories struct { - Nodes []struct { - Name string - } - PageInfo pageInfo - } `graphql:"repositories(first: 100, after: $cursor)"` - } `graphql:"organization(login: $org)"` -} - -func (r *githubRepository) listRepoForOrg() ([]string, error) { - query := listRepoForOrgQuery{} - variables := map[string]interface{}{ - "org": (githubv4.String)(r.config.Organization), - "cursor": (*githubv4.String)(nil), - } - var results []string - for { - err := r.client.Query(r.ctx, &query, variables) - if err != nil { - return nil, err - } - for _, repo := range query.Organization.Repositories.Nodes { - results = append(results, repo.Name) - } - if !query.Organization.Repositories.PageInfo.HasNextPage { - break - } - variables["cursor"] = githubv4.NewString(query.Organization.Repositories.PageInfo.EndCursor) - } - return results, nil -} - -type listRepoForOwnerQuery struct { - Viewer struct { - Repositories struct { - Nodes []struct { - Name string - } - PageInfo struct { - EndCursor githubv4.String - HasNextPage bool - } - } `graphql:"repositories(first: 100, after: $cursor)"` - } -} - -func (r githubRepository) listRepoForOwner() ([]string, error) { - query := listRepoForOwnerQuery{} - variables := map[string]interface{}{ - "cursor": (*githubv4.String)(nil), - } - var results []string - for { - err := r.client.Query(r.ctx, &query, variables) - if err != nil { - return nil, err - } - for _, repo := range query.Viewer.Repositories.Nodes { - results = append(results, repo.Name) - } - if !query.Viewer.Repositories.PageInfo.HasNextPage { - break - } - variables["cursor"] = githubv4.NewString(query.Viewer.Repositories.PageInfo.EndCursor) - } - return results, nil -} - -type listTeamsQuery struct { - Organization struct { - Teams struct { - Nodes []struct { - DatabaseId int - Slug string - } - PageInfo struct { - EndCursor githubv4.String - HasNextPage bool - } - } `graphql:"teams(first: 100, after: $cursor)"` - } `graphql:"organization(login: $login)"` -} - -type Team struct { - DatabaseId int - Slug string -} - -func (r githubRepository) ListTeams() ([]Team, error) { - if v := r.cache.Get("githubListTeams"); v != nil { - return v.([]Team), nil - } - - query := listTeamsQuery{} - results := make([]Team, 0) - if r.config.Organization == "" { - r.cache.Put("githubListTeams", results) - return results, nil - } - variables := map[string]interface{}{ - "cursor": (*githubv4.String)(nil), - "login": (githubv4.String)(r.config.Organization), - } - for { - err := r.client.Query(r.ctx, &query, variables) - if err != nil { - return nil, err - } - for _, team := range query.Organization.Teams.Nodes { - results = append(results, Team{ - DatabaseId: team.DatabaseId, - Slug: team.Slug, - }) - } - if !query.Organization.Teams.PageInfo.HasNextPage { - break - } - variables["cursor"] = githubv4.NewString(query.Organization.Teams.PageInfo.EndCursor) - } - - r.cache.Put("githubListTeams", results) - return results, nil -} - -type listMembership struct { - Organization struct { - MembersWithRole struct { - Nodes []struct { - Login string - } - PageInfo struct { - EndCursor githubv4.String - HasNextPage bool - } - } `graphql:"membersWithRole(first: 100, after: $cursor)"` - } `graphql:"organization(login: $login)"` -} - -func (r *githubRepository) ListMembership() ([]string, error) { - if v := r.cache.Get("githubListMembership"); v != nil { - return v.([]string), nil - } - - query := listMembership{} - results := make([]string, 0) - if r.config.Organization == "" { - r.cache.Put("githubListMembership", results) - return results, nil - } - variables := map[string]interface{}{ - "cursor": (*githubv4.String)(nil), - "login": (githubv4.String)(r.config.Organization), - } - for { - err := r.client.Query(r.ctx, &query, variables) - if err != nil { - return nil, err - } - for _, membership := range query.Organization.MembersWithRole.Nodes { - results = append(results, fmt.Sprintf("%s:%s", r.config.Organization, membership.Login)) - } - if !query.Organization.MembersWithRole.PageInfo.HasNextPage { - break - } - variables["cursor"] = githubv4.NewString(query.Organization.MembersWithRole.PageInfo.EndCursor) - } - - r.cache.Put("githubListMembership", results) - return results, nil -} - -type listTeamMembershipsQuery struct { - Organization struct { - Team struct { - Members struct { - Nodes []struct { - Login string - } - PageInfo struct { - EndCursor githubv4.String - HasNextPage bool - } - } `graphql:"members(first: 100, after: $cursor)"` - } `graphql:"team(slug: $slug)"` - } `graphql:"organization(login: $login)"` -} - -func (r githubRepository) ListTeamMemberships() ([]string, error) { - if v := r.cache.Get("githubListTeamMemberships"); v != nil { - return v.([]string), nil - } - - teamList, err := r.ListTeams() - if err != nil { - return nil, err - } - - query := listTeamMembershipsQuery{} - results := make([]string, 0) - if r.config.Organization == "" { - r.cache.Put("githubListTeamMemberships", results) - return results, nil - } - variables := map[string]interface{}{ - "login": (githubv4.String)(r.config.Organization), - } - - for _, team := range teamList { - variables["slug"] = (githubv4.String)(team.Slug) - variables["cursor"] = (*githubv4.String)(nil) - for { - err := r.client.Query(r.ctx, &query, variables) - if err != nil { - return nil, err - } - for _, membership := range query.Organization.Team.Members.Nodes { - results = append(results, fmt.Sprintf("%d:%s", team.DatabaseId, membership.Login)) - } - if !query.Organization.Team.Members.PageInfo.HasNextPage { - break - } - variables["cursor"] = query.Organization.Team.Members.PageInfo.EndCursor - } - } - - r.cache.Put("githubListTeamMemberships", results) - return results, nil -} - -type listBranchProtectionQuery struct { - Repository struct { - BranchProtectionRules struct { - Nodes []struct { - Id string - } - PageInfo struct { - EndCursor githubv4.String - HasNextPage bool - } - } `graphql:"branchProtectionRules(first: 1, after: $cursor)"` - } `graphql:"repository(owner: $owner, name: $name)"` -} - -func (r *githubRepository) ListBranchProtection() ([]string, error) { - if v := r.cache.Get("githubListBranchProtection"); v != nil { - return v.([]string), nil - } - - repoList, err := r.ListRepositories() - if err != nil { - return nil, err - } - - results := make([]string, 0) - query := listBranchProtectionQuery{} - variables := map[string]interface{}{ - "cursor": (*githubv4.String)(nil), - "owner": (githubv4.String)(r.config.getDefaultOwner()), - "name": (githubv4.String)(""), - } - - for _, repo := range repoList { - variables["name"] = (githubv4.String)(repo) - variables["cursor"] = (*githubv4.String)(nil) - for { - err := r.client.Query(r.ctx, &query, variables) - if err != nil { - return nil, err - } - for _, protection := range query.Repository.BranchProtectionRules.Nodes { - results = append(results, protection.Id) - } - - variables["cursor"] = query.Repository.BranchProtectionRules.PageInfo.EndCursor - - if !query.Repository.BranchProtectionRules.PageInfo.HasNextPage { - break - } - } - - } - - r.cache.Put("githubListBranchProtection", results) - return results, nil -} diff --git a/pkg/remote/github/repository_test.go b/pkg/remote/github/repository_test.go deleted file mode 100644 index 8b95e64c..00000000 --- a/pkg/remote/github/repository_test.go +++ /dev/null @@ -1,920 +0,0 @@ -package github - -import ( - "context" - "testing" - - "github.com/pkg/errors" - "github.com/shurcooL/githubv4" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestListRepositoriesForUser_WithError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{}, - cache: cache.New(1), - } - - _, err := r.ListRepositories() - assert.Equal(t, expectedError, err) -} - -func TestListRepositoriesForUser(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOwnerQuery) - if !ok { - return false - } - q.Viewer.Repositories.Nodes = []struct{ Name string }{ - { - Name: "repo1", - }, - { - Name: "repo2", - }, - } - q.Viewer.Repositories.PageInfo = pageInfo{ - EndCursor: "next", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOwnerQuery) - if !ok { - return false - } - q.Viewer.Repositories.Nodes = []struct{ Name string }{ - { - Name: "repo3", - }, - { - Name: "repo4", - }, - } - q.Viewer.Repositories.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "cursor": githubv4.NewString("next"), - }).Return(nil).Once() - - store := cache.New(1) - r := githubRepository{ - client: &mockedClient, - ctx: context.TODO(), - config: githubConfig{}, - cache: store, - } - - repos, err := r.ListRepositories() - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, []string{ - "repo1", - "repo2", - "repo3", - "repo4", - }, repos) - - // Check that results were cached - cachedData, err := r.ListRepositories() - assert.NoError(t, err) - assert.Equal(t, repos, cachedData) - assert.IsType(t, []string{}, store.Get("githubListRepositories")) -} - -func TestListRepositoriesForOrganization_WithError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "testorg", - }, - cache: cache.New(1), - } - - _, err := r.ListRepositories() - assert.Equal(t, expectedError, err) -} - -func TestListRepositoriesForOrganization(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOrgQuery) - if !ok { - return false - } - q.Organization.Repositories.Nodes = []struct { - Name string - }{ - { - Name: "repo1", - }, - { - Name: "repo2", - }, - } - q.Organization.Repositories.PageInfo = pageInfo{ - EndCursor: "next", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "org": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOrgQuery) - if !ok { - return false - } - q.Organization.Repositories.Nodes = []struct { - Name string - }{ - { - Name: "repo3", - }, - { - Name: "repo4", - }, - } - q.Organization.Repositories.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "org": (githubv4.String)("testorg"), - "cursor": githubv4.NewString("next"), - }).Return(nil).Once() - - store := cache.New(1) - r := githubRepository{ - client: &mockedClient, - ctx: context.TODO(), - config: githubConfig{ - Organization: "testorg", - }, - cache: store, - } - - repos, err := r.ListRepositories() - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, []string{ - "repo1", - "repo2", - "repo3", - "repo4", - }, repos) - - // Check that results were cached - cachedData, err := r.ListRepositories() - assert.NoError(t, err) - assert.Equal(t, repos, cachedData) - assert.IsType(t, []string{}, store.Get("githubListRepositories")) -} - -func TestListTeams_WithError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "testorg", - }, - cache: cache.New(1), - } - - _, err := r.ListTeams() - assert.Equal(t, expectedError, err) -} - -func TestListTeams_WithoutOrganization(t *testing.T) { - r := githubRepository{cache: cache.New(1)} - - teams, err := r.ListTeams() - assert.Nil(t, err) - assert.Equal(t, []Team{}, teams) -} - -func TestListTeams(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamsQuery) - if !ok { - return false - } - q.Organization.Teams.Nodes = []struct { - DatabaseId int - Slug string - }{ - { - DatabaseId: 1, - Slug: "1", - }, - { - DatabaseId: 2, - Slug: "2", - }, - } - q.Organization.Teams.PageInfo = pageInfo{ - EndCursor: "next", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamsQuery) - if !ok { - return false - } - q.Organization.Teams.Nodes = []struct { - DatabaseId int - Slug string - }{ - { - DatabaseId: 3, - Slug: "3", - }, - { - DatabaseId: 4, - Slug: "4", - }, - } - q.Organization.Teams.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": githubv4.NewString("next"), - }).Return(nil).Once() - - store := cache.New(1) - r := githubRepository{ - client: &mockedClient, - ctx: context.TODO(), - config: githubConfig{ - Organization: "testorg", - }, - cache: store, - } - - teams, err := r.ListTeams() - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, []Team{ - {1, "1"}, - {2, "2"}, - {3, "3"}, - {4, "4"}, - }, teams) - - // Check that results were cached - cachedData, err := r.ListTeams() - assert.NoError(t, err) - assert.Equal(t, teams, cachedData) - assert.IsType(t, []Team{}, store.Get("githubListTeams")) -} - -func TestListTeamMemberships_WithTeamListingError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "testorg", - }, - cache: cache.New(1), - } - - _, err := r.ListTeamMemberships() - assert.Equal(t, expectedError, err) -} - -func TestListTeamMemberships_WithError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamsQuery) - if !ok { - return false - } - q.Organization.Teams.Nodes = []struct { - DatabaseId int - Slug string - }{ - { - DatabaseId: 1, - Slug: "foo", - }, - } - q.Organization.Teams.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - }).Return(nil) - - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "testorg", - }, - cache: cache.New(1), - } - - _, err := r.ListTeamMemberships() - assert.Equal(t, expectedError, err) -} - -func TestListTeamMemberships_WithoutOrganization(t *testing.T) { - r := githubRepository{cache: cache.New(1)} - - teams, err := r.ListTeamMemberships() - assert.Nil(t, err) - assert.Equal(t, []string{}, teams) -} - -func TestListTeamMemberships(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamsQuery) - if !ok { - return false - } - q.Organization.Teams.Nodes = []struct { - DatabaseId int - Slug string - }{ - { - DatabaseId: 1, - Slug: "foo", - }, - { - DatabaseId: 2, - Slug: "bar", - }, - } - q.Organization.Teams.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamMembershipsQuery) - if !ok { - return false - } - q.Organization.Team.Members.Nodes = []struct { - Login string - }{ - { - Login: "user-1", - }, - { - Login: "user-2", - }, - } - q.Organization.Team.Members.PageInfo = pageInfo{ - EndCursor: "next", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - "slug": (githubv4.String)("foo"), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamMembershipsQuery) - if !ok { - return false - } - q.Organization.Team.Members.Nodes = []struct { - Login string - }{ - { - Login: "user-3", - }, - { - Login: "user-4", - }, - } - q.Organization.Team.Members.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (githubv4.String)("next"), - "slug": (githubv4.String)("foo"), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listTeamMembershipsQuery) - if !ok { - return false - } - q.Organization.Team.Members.Nodes = []struct { - Login string - }{ - { - Login: "user-5", - }, - { - Login: "user-6", - }, - } - q.Organization.Team.Members.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - "slug": (githubv4.String)("bar"), - }).Return(nil).Once() - - store := cache.New(1) - r := githubRepository{ - client: &mockedClient, - ctx: context.TODO(), - config: githubConfig{ - Organization: "testorg", - }, - cache: store, - } - - memberships, err := r.ListTeamMemberships() - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, []string{ - "1:user-1", - "1:user-2", - "1:user-3", - "1:user-4", - "2:user-5", - "2:user-6", - }, memberships) - - // Check that results were cached - cachedData, err := r.ListTeamMemberships() - assert.NoError(t, err) - assert.Equal(t, memberships, cachedData) - assert.IsType(t, []string{}, store.Get("githubListTeamMemberships")) -} - -func TestListMembership_WithError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "testorg", - }, - cache: cache.New(1), - } - - _, err := r.ListMembership() - assert.Equal(t, expectedError, err) -} - -func TestListMembership_WithoutOrganization(t *testing.T) { - r := githubRepository{cache: cache.New(1)} - - teams, err := r.ListMembership() - assert.Nil(t, err) - assert.Equal(t, []string{}, teams) -} - -func TestListMembership(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listMembership) - if !ok { - return false - } - q.Organization.MembersWithRole.Nodes = []struct { - Login string - }{ - { - Login: "user-admin", - }, - { - Login: "user-non-admin-1", - }, - } - q.Organization.MembersWithRole.PageInfo = pageInfo{ - EndCursor: "next", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listMembership) - if !ok { - return false - } - q.Organization.MembersWithRole.Nodes = []struct { - Login string - }{ - { - Login: "user-non-admin-2", - }, - { - Login: "user-non-admin-3", - }, - } - q.Organization.MembersWithRole.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "login": (githubv4.String)("testorg"), - "cursor": githubv4.NewString("next"), - }).Return(nil).Once() - - store := cache.New(1) - r := githubRepository{ - client: &mockedClient, - ctx: context.TODO(), - config: githubConfig{ - Organization: "testorg", - }, - cache: store, - } - - teams, err := r.ListMembership() - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, []string{ - "testorg:user-admin", - "testorg:user-non-admin-1", - "testorg:user-non-admin-2", - "testorg:user-non-admin-3", - }, teams) - - // Check that results were cached - cachedData, err := r.ListMembership() - assert.NoError(t, err) - assert.Equal(t, teams, cachedData) - assert.IsType(t, []string{}, store.Get("githubListMembership")) - -} - -func TestListBranchProtection_WithRepoListingError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOrgQuery) - if !ok { - return false - } - q.Organization.Repositories.Nodes = []struct { - Name string - }{ - { - Name: "repo1", - }, - { - Name: "repo2", - }, - } - q.Organization.Repositories.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "org": (githubv4.String)("my-organization"), - "cursor": (*githubv4.String)(nil), - }).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "my-organization", - }, - cache: cache.New(1), - } - - _, err := r.ListBranchProtection() - assert.Equal(t, expectedError, err) -} - -func TestListBranchProtection_WithError(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - expectedError := errors.New("test error from graphql") - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOrgQuery) - if !ok { - return false - } - q.Organization.Repositories.Nodes = []struct { - Name string - }{ - { - Name: "repo1", - }, - { - Name: "repo2", - }, - } - q.Organization.Repositories.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "org": (githubv4.String)("testorg"), - "cursor": (*githubv4.String)(nil), - }).Return(nil) - - mockedClient.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(expectedError) - - r := githubRepository{ - client: &mockedClient, - config: githubConfig{ - Organization: "testorg", - }, - cache: cache.New(1), - } - - _, err := r.ListBranchProtection() - assert.Equal(t, expectedError, err) -} - -func TestListBranchProtection(t *testing.T) { - mockedClient := mocks.GithubGraphQLClient{} - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listRepoForOrgQuery) - if !ok { - return false - } - q.Organization.Repositories.Nodes = []struct { - Name string - }{ - { - Name: "repo1", - }, - { - Name: "repo2", - }, - } - q.Organization.Repositories.PageInfo = pageInfo{ - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "org": (githubv4.String)("my-organization"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listBranchProtectionQuery) - if !ok { - return false - } - q.Repository.BranchProtectionRules.Nodes = []struct { - Id string - }{ - { - Id: "id1", - }, - { - Id: "id2", - }, - } - q.Repository.BranchProtectionRules.PageInfo = pageInfo{ - EndCursor: "nextPage", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "owner": (githubv4.String)("my-organization"), - "name": (githubv4.String)("repo1"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listBranchProtectionQuery) - if !ok { - return false - } - q.Repository.BranchProtectionRules.Nodes = []struct { - Id string - }{ - { - Id: "id3", - }, - { - Id: "id4", - }, - } - q.Repository.BranchProtectionRules.PageInfo = pageInfo{ - EndCursor: "nextPage", - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "owner": (githubv4.String)("my-organization"), - "name": (githubv4.String)("repo1"), - "cursor": (githubv4.String)("nextPage"), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listBranchProtectionQuery) - if !ok { - return false - } - q.Repository.BranchProtectionRules.Nodes = []struct { - Id string - }{ - { - Id: "id5", - }, - { - Id: "id6", - }, - } - q.Repository.BranchProtectionRules.PageInfo = pageInfo{ - EndCursor: "nextPage", - HasNextPage: true, - } - return true - }), - map[string]interface{}{ - "owner": (githubv4.String)("my-organization"), - "name": (githubv4.String)("repo2"), - "cursor": (*githubv4.String)(nil), - }).Return(nil).Once() - - mockedClient.On("Query", - mock.Anything, - mock.MatchedBy(func(query interface{}) bool { - q, ok := query.(*listBranchProtectionQuery) - if !ok { - return false - } - q.Repository.BranchProtectionRules.Nodes = []struct { - Id string - }{ - { - Id: "id7", - }, - { - Id: "id8", - }, - } - q.Repository.BranchProtectionRules.PageInfo = pageInfo{ - EndCursor: "nextPage", - HasNextPage: false, - } - return true - }), - map[string]interface{}{ - "owner": (githubv4.String)("my-organization"), - "name": (githubv4.String)("repo2"), - "cursor": (githubv4.String)("nextPage"), - }).Return(nil).Once() - - store := cache.New(1) - r := githubRepository{ - client: &mockedClient, - ctx: context.TODO(), - config: githubConfig{ - Organization: "my-organization", - }, - cache: store, - } - - teams, err := r.ListBranchProtection() - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, []string{ - "id1", - "id2", - "id3", - "id4", - "id5", - "id6", - "id7", - "id8", - }, teams) - - // Check that results were cached - cachedData, err := r.ListBranchProtection() - assert.NoError(t, err) - assert.Equal(t, teams, cachedData) - assert.IsType(t, []string{}, store.Get("githubListBranchProtection")) -} diff --git a/pkg/remote/github_branch_protection_scanner_test.go b/pkg/remote/github_branch_protection_scanner_test.go deleted file mode 100644 index 8cf2b31d..00000000 --- a/pkg/remote/github_branch_protection_scanner_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/github" - githubres "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - tftest "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanGithubBranchProtection(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*github.MockGithubRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no branch protection", - dirName: "github_branch_protection_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListBranchProtection").Return([]string{}, nil) - }, - err: nil, - }, - { - test: "Multiple branch protections", - dirName: "github_branch_protection_multiples", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListBranchProtection").Return([]string{ - "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzI=", //"repo0:main" - "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzg=", //"repo0:toto" - "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzQ=", //"repo1:main" - "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0ODA=", //"repo1:toto" - "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0NzE=", //"repo2:main" - "MDIwOkJyYW5jaFByb3RlY3Rpb25SdWxlMTk1NDg0Nzc=", //"repo2:toto" - }, nil) - }, - err: nil, - }, - { - test: "cannot list branch protections", - dirName: "github_branch_protection_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListBranchProtection").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) - - alerter.On("SendAlert", githubres.GithubBranchProtectionResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubBranchProtectionResourceType, githubres.GithubBranchProtectionResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") - githubres.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - mockedRepo := github.MockGithubRepository{} - c.mocks(&mockedRepo, alerter) - - var repo github.GithubRepository = &mockedRepo - - realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") - if err != nil { - t.Fatal(err) - } - provider := tftest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = github.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(github.NewGithubBranchProtectionEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(githubres.GithubBranchProtectionResourceType, common.NewGenericDetailsFetcher(githubres.GithubBranchProtectionResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, githubres.GithubBranchProtectionResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - mockedRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/github_membership_scanner_test.go b/pkg/remote/github_membership_scanner_test.go deleted file mode 100644 index c3fbda9d..00000000 --- a/pkg/remote/github_membership_scanner_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/github" - githubres "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - tftest "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanGithubMembership(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*github.MockGithubRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no members", - dirName: "github_membership_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListMembership").Return([]string{}, nil) - }, - err: nil, - }, - { - test: "Multiple membership with admin and member roles", - dirName: "github_membership_multiple", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListMembership").Return([]string{ - "driftctl-test:driftctl-acceptance-tester", - "driftctl-test:eliecharra", - }, nil) - }, - err: nil, - }, - { - test: "cannot list membership", - dirName: "github_membership_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListMembership").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) - - alerter.On("SendAlert", githubres.GithubMembershipResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubMembershipResourceType, githubres.GithubMembershipResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") - githubres.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - mockedRepo := github.MockGithubRepository{} - c.mocks(&mockedRepo, alerter) - - var repo github.GithubRepository = &mockedRepo - - realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") - if err != nil { - t.Fatal(err) - } - provider := tftest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = github.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(github.NewGithubMembershipEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(githubres.GithubMembershipResourceType, common.NewGenericDetailsFetcher(githubres.GithubMembershipResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, githubres.GithubMembershipResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - mockedRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/github_repository_scanner_test.go b/pkg/remote/github_repository_scanner_test.go deleted file mode 100644 index f0b2a8c1..00000000 --- a/pkg/remote/github_repository_scanner_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/github" - githubres "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - tftest "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanGithubRepository(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*github.MockGithubRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no github repos", - dirName: "github_repository_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListRepositories").Return([]string{}, nil) - }, - err: nil, - }, - { - test: "Multiple github repos Table", - dirName: "github_repository_multiple", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListRepositories").Return([]string{ - "driftctl", - "driftctl-demos", - }, nil) - }, - err: nil, - }, - { - test: "cannot list repositories", - dirName: "github_repository_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListRepositories").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) - - alerter.On("SendAlert", githubres.GithubRepositoryResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubRepositoryResourceType, githubres.GithubRepositoryResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") - githubres.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - mockedRepo := github.MockGithubRepository{} - c.mocks(&mockedRepo, alerter) - - var repo github.GithubRepository = &mockedRepo - - realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") - if err != nil { - t.Fatal(err) - } - provider := tftest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = github.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(github.NewGithubRepositoryEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(githubres.GithubRepositoryResourceType, common.NewGenericDetailsFetcher(githubres.GithubRepositoryResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, githubres.GithubRepositoryResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - mockedRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/github_team_membership_scanner_test.go b/pkg/remote/github_team_membership_scanner_test.go deleted file mode 100644 index d2c5f1b0..00000000 --- a/pkg/remote/github_team_membership_scanner_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/github" - githubres "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - tftest "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanGithubTeamMembership(t *testing.T) { - - cases := []struct { - test string - dirName string - mocks func(*github.MockGithubRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no github team memberships", - dirName: "github_team_membership_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListTeamMemberships").Return([]string{}, nil) - }, - err: nil, - }, - { - test: "multiple github team memberships", - dirName: "github_team_membership_multiple", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListTeamMemberships").Return([]string{ - "4570529:driftctl-acceptance-tester", - "4570529:wbeuil", - }, nil) - }, - err: nil, - }, - { - test: "cannot list team membership", - dirName: "github_team_membership_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListTeamMemberships").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) - - alerter.On("SendAlert", githubres.GithubTeamMembershipResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubTeamMembershipResourceType, githubres.GithubTeamMembershipResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") - githubres.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - mockedRepo := github.MockGithubRepository{} - c.mocks(&mockedRepo, alerter) - - var repo github.GithubRepository = &mockedRepo - - realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") - if err != nil { - t.Fatal(err) - } - provider := tftest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = github.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(github.NewGithubTeamMembershipEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(githubres.GithubTeamMembershipResourceType, common.NewGenericDetailsFetcher(githubres.GithubTeamMembershipResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, githubres.GithubTeamMembershipResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - mockedRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/github_team_scanner_test.go b/pkg/remote/github_team_scanner_test.go deleted file mode 100644 index 6dc51722..00000000 --- a/pkg/remote/github_team_scanner_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/github" - githubres "github.com/snyk/driftctl/pkg/resource/github" - "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - tftest "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/mock" - - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - "github.com/stretchr/testify/assert" -) - -func TestScanGithubTeam(t *testing.T) { - - tests := []struct { - test string - dirName string - mocks func(*github.MockGithubRepository, *mocks.AlerterInterface) - err error - }{ - { - test: "no github teams", - dirName: "github_teams_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListTeams").Return([]github.Team{}, nil) - }, - err: nil, - }, - { - test: "Multiple github teams with parent", - dirName: "github_teams_multiple", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListTeams").Return([]github.Team{ - {DatabaseId: 4556811}, // github_team.team1 - {DatabaseId: 4556812}, // github_team.team2 - {DatabaseId: 4556814}, // github_team.with_parent - }, nil) - }, - err: nil, - }, - { - test: "cannot list teams", - dirName: "github_teams_empty", - mocks: func(client *github.MockGithubRepository, alerter *mocks.AlerterInterface) { - client.On("ListTeams").Return(nil, errors.New("Your token has not been granted the required scopes to execute this query.")) - - alerter.On("SendAlert", githubres.GithubTeamResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), githubres.GithubTeamResourceType, githubres.GithubTeamResourceType), alerts.EnumerationPhase)).Return() - }, - err: nil, - }, - } - - schemaRepository := testresource.InitFakeSchemaRepository("github", "4.4.0") - githubres.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range tests { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - mockedRepo := github.MockGithubRepository{} - c.mocks(&mockedRepo, alerter) - - var repo github.GithubRepository = &mockedRepo - - realProvider, err := tftest.InitTestGithubProvider(providerLibrary, "4.4.0") - if err != nil { - t.Fatal(err) - } - provider := tftest.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - if shouldUpdate { - err := realProvider.Init() - if err != nil { - t.Fatal(err) - } - provider.ShouldUpdate() - repo = github.NewGithubRepository(realProvider.GetConfig(), cache.New(0)) - } - - remoteLibrary.AddEnumerator(github.NewGithubTeamEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(githubres.GithubTeamResourceType, common.NewGenericDetailsFetcher(githubres.GithubTeamResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.err) - if err != nil { - return - } - test.TestAgainstGoldenFile(got, githubres.GithubTeamResourceType, c.dirName, provider, deserializer, shouldUpdate, tt) - mockedRepo.AssertExpectations(tt) - alerter.AssertExpectations(tt) - }) - } -} diff --git a/pkg/remote/google/google_bigquery_dataset_enumerator.go b/pkg/remote/google/google_bigquery_dataset_enumerator.go deleted file mode 100644 index a3d9276a..00000000 --- a/pkg/remote/google/google_bigquery_dataset_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleBigqueryDatasetEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleBigqueryDatasetEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigqueryDatasetEnumerator { - return &GoogleBigqueryDatasetEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleBigqueryDatasetEnumerator) SupportedType() resource.ResourceType { - return google.GoogleBigqueryDatasetResourceType -} - -func (e *GoogleBigqueryDatasetEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllDatasets() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "friendly_name": res.DisplayName, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_bigquery_table_enumerator.go b/pkg/remote/google/google_bigquery_table_enumerator.go deleted file mode 100644 index 465db973..00000000 --- a/pkg/remote/google/google_bigquery_table_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleBigqueryTableEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleBigqueryTableEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigqueryTableEnumerator { - return &GoogleBigqueryTableEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleBigqueryTableEnumerator) SupportedType() resource.ResourceType { - return google.GoogleBigqueryTableResourceType -} - -func (e *GoogleBigqueryTableEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllTables() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "friendly_name": res.DisplayName, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_bigtable_instance_enumerator.go b/pkg/remote/google/google_bigtable_instance_enumerator.go deleted file mode 100644 index 9bce37d1..00000000 --- a/pkg/remote/google/google_bigtable_instance_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package google - -import ( - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleBigTableInstanceEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleBigTableInstanceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigTableInstanceEnumerator { - return &GoogleBigTableInstanceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleBigTableInstanceEnumerator) SupportedType() resource.ResourceType { - return google.GoogleBigTableInstanceResourceType -} - -func (e *GoogleBigTableInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllBigtableInstances() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - name, exist := res.GetResource().GetData().GetFields()["name"] - if !exist || name.GetStringValue() == "" { - logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - name.GetStringValue(), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_bigtable_table_enumerator.go b/pkg/remote/google/google_bigtable_table_enumerator.go deleted file mode 100644 index 1257bd69..00000000 --- a/pkg/remote/google/google_bigtable_table_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package google - -import ( - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleBigtableTableEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleBigtableTableEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleBigtableTableEnumerator { - return &GoogleBigtableTableEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleBigtableTableEnumerator) SupportedType() resource.ResourceType { - return google.GoogleBigtableTableResourceType -} - -func (e *GoogleBigtableTableEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllBigtableTables() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - name, exist := res.GetResource().GetData().GetFields()["name"] - if !exist || name.GetStringValue() == "" { - logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - name.GetStringValue(), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_cloudfunctions_function_enumerator.go b/pkg/remote/google/google_cloudfunctions_function_enumerator.go deleted file mode 100644 index 21a9513a..00000000 --- a/pkg/remote/google/google_cloudfunctions_function_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package google - -import ( - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleCloudFunctionsFunctionEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleCloudFunctionsFunctionEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleCloudFunctionsFunctionEnumerator { - return &GoogleCloudFunctionsFunctionEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleCloudFunctionsFunctionEnumerator) SupportedType() resource.ResourceType { - return google.GoogleCloudFunctionsFunctionResourceType -} - -func (e *GoogleCloudFunctionsFunctionEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllFunctions() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - name, exist := res.GetResource().GetData().GetFields()["name"] - if !exist || name.GetStringValue() == "" { - logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - name.GetStringValue(), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_cloudrun_service_enumerator.go b/pkg/remote/google/google_cloudrun_service_enumerator.go deleted file mode 100644 index a3db094f..00000000 --- a/pkg/remote/google/google_cloudrun_service_enumerator.go +++ /dev/null @@ -1,62 +0,0 @@ -package google - -import ( - "strings" - - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleCloudRunServiceEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleCloudRunServiceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleCloudRunServiceEnumerator { - return &GoogleCloudRunServiceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleCloudRunServiceEnumerator) SupportedType() resource.ResourceType { - return google.GoogleCloudRunServiceResourceType -} - -func (e *GoogleCloudRunServiceEnumerator) Enumerate() ([]*resource.Resource, error) { - subnets, err := e.repository.SearchAllCloudRunServices() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(subnets)) - - for _, res := range subnets { - splittedName := strings.Split(res.GetName(), "/") - if len(splittedName) != 9 { - logrus.WithField("name", res.GetName()).Error("Unable to decode project from resource name") - continue - } - project := splittedName[4] - id := strings.Join([]string{ - "locations", res.GetLocation(), - "namespaces", project, - "services", res.GetDisplayName(), - }, "/") - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{ - "name": res.GetDisplayName(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_address_enumerator.go b/pkg/remote/google/google_compute_address_enumerator.go deleted file mode 100644 index 419d3aa5..00000000 --- a/pkg/remote/google/google_compute_address_enumerator.go +++ /dev/null @@ -1,58 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeAddressEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeAddressEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeAddressEnumerator { - return &GoogleComputeAddressEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeAddressEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeAddressResourceType -} - -func (e *GoogleComputeAddressEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllAddresses() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - // Global addresses are handled as a dedicated resource - if res.GetLocation() == "global" { - continue - } - address := "" - if addr, exist := res.GetAdditionalAttributes().GetFields()["address"]; exist { - address = addr.GetStringValue() - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetDisplayName(), - "address": address, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_disk_enumerator.go b/pkg/remote/google/google_compute_disk_enumerator.go deleted file mode 100644 index a7786db0..00000000 --- a/pkg/remote/google/google_compute_disk_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeDiskEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeDiskEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeDiskEnumerator { - return &GoogleComputeDiskEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeDiskEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeDiskResourceType -} - -func (e *GoogleComputeDiskEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllDisks() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetDisplayName(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_firewall_enumerator.go b/pkg/remote/google/google_compute_firewall_enumerator.go deleted file mode 100644 index b6b71568..00000000 --- a/pkg/remote/google/google_compute_firewall_enumerator.go +++ /dev/null @@ -1,59 +0,0 @@ -package google - -import ( - "strings" - - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeFirewallEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeFirewallEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeFirewallEnumerator { - return &GoogleComputeFirewallEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeFirewallEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeFirewallResourceType -} - -func (e *GoogleComputeFirewallEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllFirewalls() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - splittedName := strings.Split(res.GetName(), "/") - if len(splittedName) != 8 { - logrus.WithField("name", res.GetName()).Error("Unable to decode project from firewall name") - continue - } - project := splittedName[4] - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.DisplayName, - "project": project, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_forwarding_rule_enumerator.go b/pkg/remote/google/google_compute_forwarding_rule_enumerator.go deleted file mode 100644 index 99561491..00000000 --- a/pkg/remote/google/google_compute_forwarding_rule_enumerator.go +++ /dev/null @@ -1,45 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeForwardingRuleEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeForwardingRuleEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeForwardingRuleEnumerator { - return &GoogleComputeForwardingRuleEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeForwardingRuleEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeForwardingRuleResourceType -} - -func (e *GoogleComputeForwardingRuleEnumerator) Enumerate() ([]*resource.Resource, error) { - forwardingRules, err := e.repository.SearchAllForwardingRules() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(forwardingRules)) - for _, res := range forwardingRules { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_global_address_enumerator.go b/pkg/remote/google/google_compute_global_address_enumerator.go deleted file mode 100644 index 0980e2be..00000000 --- a/pkg/remote/google/google_compute_global_address_enumerator.go +++ /dev/null @@ -1,60 +0,0 @@ -package google - -import ( - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeGlobalAddressEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeGlobalAddressEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeGlobalAddressEnumerator { - return &GoogleComputeGlobalAddressEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeGlobalAddressEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeGlobalAddressResourceType -} - -func (e *GoogleComputeGlobalAddressEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllGlobalAddresses() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - name, exist := res.GetResource().GetData().GetFields()["name"] - if !exist || name.GetStringValue() == "" { - logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") - continue - } - address := "" - if addr, exist := res.GetResource().GetData().GetFields()["address"]; exist { - address = addr.GetStringValue() - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": name.GetStringValue(), - "address": address, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_global_forwarding_rule_enumerator.go b/pkg/remote/google/google_compute_global_forwarding_rule_enumerator.go deleted file mode 100644 index 8e293a2d..00000000 --- a/pkg/remote/google/google_compute_global_forwarding_rule_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeGlobalForwardingRuleEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeGlobalForwardingRuleEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeGlobalForwardingRuleEnumerator { - return &GoogleComputeGlobalForwardingRuleEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeGlobalForwardingRuleEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeGlobalForwardingRuleResourceType -} - -func (e *GoogleComputeGlobalForwardingRuleEnumerator) Enumerate() ([]*resource.Resource, error) { - globalForwardingRules, err := e.repository.SearchAllGlobalForwardingRules() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(globalForwardingRules)) - - for _, res := range globalForwardingRules { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_health_check_enumerator.go b/pkg/remote/google/google_compute_health_check_enumerator.go deleted file mode 100644 index 04631560..00000000 --- a/pkg/remote/google/google_compute_health_check_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeHealthCheckEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeHealthCheckEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeHealthCheckEnumerator { - return &GoogleComputeHealthCheckEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeHealthCheckEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeHealthCheckResourceType -} - -func (e *GoogleComputeHealthCheckEnumerator) Enumerate() ([]*resource.Resource, error) { - checks, err := e.repository.SearchAllHealthChecks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(checks)) - for _, res := range checks { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetDisplayName(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_image_enumerator.go b/pkg/remote/google/google_compute_image_enumerator.go deleted file mode 100644 index 7ee88314..00000000 --- a/pkg/remote/google/google_compute_image_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeImageEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeImageEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeImageEnumerator { - return &GoogleComputeImageEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeImageEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeImageResourceType -} - -func (e *GoogleComputeImageEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllImages() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetDisplayName(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_instance_enumerator.go b/pkg/remote/google/google_compute_instance_enumerator.go deleted file mode 100644 index 05f87830..00000000 --- a/pkg/remote/google/google_compute_instance_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeInstanceEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeInstanceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeInstanceEnumerator { - return &GoogleComputeInstanceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeInstanceEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeInstanceResourceType -} - -func (e *GoogleComputeInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllInstances() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_instance_group_enumerator.go b/pkg/remote/google/google_compute_instance_group_enumerator.go deleted file mode 100644 index eadc4e6a..00000000 --- a/pkg/remote/google/google_compute_instance_group_enumerator.go +++ /dev/null @@ -1,58 +0,0 @@ -package google - -import ( - "strings" - - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeInstanceGroupEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeInstanceGroupEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeInstanceGroupEnumerator { - return &GoogleComputeInstanceGroupEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeInstanceGroupEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeInstanceGroupResourceType -} - -func (e *GoogleComputeInstanceGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - groups, err := e.repository.SearchAllInstanceGroups() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(groups)) - for _, res := range groups { - splittedName := strings.Split(res.GetName(), "/") - if len(splittedName) != 9 { - logrus.WithField("name", res.GetName()).Error("Unable to decode project from instance group name") - continue - } - project := splittedName[4] - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetDisplayName(), - "project": project, - "location": res.GetLocation(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_instance_group_manager_enumerator.go b/pkg/remote/google/google_compute_instance_group_manager_enumerator.go deleted file mode 100644 index 6c466aea..00000000 --- a/pkg/remote/google/google_compute_instance_group_manager_enumerator.go +++ /dev/null @@ -1,56 +0,0 @@ -package google - -import ( - "strings" - - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeInstanceGroupManagerEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeInstanceGroupManagerEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeInstanceGroupManagerEnumerator { - return &GoogleComputeInstanceGroupManagerEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeInstanceGroupManagerEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeInstanceGroupManagerResourceType -} - -func (e *GoogleComputeInstanceGroupManagerEnumerator) Enumerate() ([]*resource.Resource, error) { - items, err := e.repository.SearchAllInstanceGroupManagers() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(items)) - for _, res := range items { - splittedName := strings.Split(res.GetName(), "/") - if len(splittedName) != 9 { - logrus.WithField("name", res.GetName()).Error("Unable to decode project from instance group name") - continue - } - name := splittedName[8] - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": name, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_network_enumerator.go b/pkg/remote/google/google_compute_network_enumerator.go deleted file mode 100644 index 07e0e262..00000000 --- a/pkg/remote/google/google_compute_network_enumerator.go +++ /dev/null @@ -1,48 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeNetworkEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeNetworkEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeNetworkEnumerator { - return &GoogleComputeNetworkEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeNetworkEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeNetworkResourceType -} - -func (e *GoogleComputeNetworkEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllNetworks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.DisplayName, - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_node_group_enumerator.go b/pkg/remote/google/google_compute_node_group_enumerator.go deleted file mode 100644 index 910b9372..00000000 --- a/pkg/remote/google/google_compute_node_group_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeNodeGroupEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeNodeGroupEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeNodeGroupEnumerator { - return &GoogleComputeNodeGroupEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeNodeGroupEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeNodeGroupResourceType -} - -func (e *GoogleComputeNodeGroupEnumerator) Enumerate() ([]*resource.Resource, error) { - nodeGroups, err := e.repository.SearchAllNodeGroups() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(nodeGroups)) - for _, res := range nodeGroups { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetName(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_router_enumerator.go b/pkg/remote/google/google_compute_router_enumerator.go deleted file mode 100644 index 8b8ec802..00000000 --- a/pkg/remote/google/google_compute_router_enumerator.go +++ /dev/null @@ -1,46 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeRouterEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeRouterEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeRouterEnumerator { - return &GoogleComputeRouterEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeRouterEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeRouterResourceType -} - -func (e *GoogleComputeRouterEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllRouters() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_compute_subnetwork_enumerator.go b/pkg/remote/google/google_compute_subnetwork_enumerator.go deleted file mode 100644 index 307eee79..00000000 --- a/pkg/remote/google/google_compute_subnetwork_enumerator.go +++ /dev/null @@ -1,49 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleComputeSubnetworkEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleComputeSubnetworkEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleComputeSubnetworkEnumerator { - return &GoogleComputeSubnetworkEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleComputeSubnetworkEnumerator) SupportedType() resource.ResourceType { - return google.GoogleComputeSubnetworkResourceType -} - -func (e *GoogleComputeSubnetworkEnumerator) Enumerate() ([]*resource.Resource, error) { - subnets, err := e.repository.SearchAllSubnetworks() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(subnets)) - - for _, res := range subnets { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - trimResourceName(res.GetName()), - map[string]interface{}{ - "name": res.GetDisplayName(), - "region": res.GetLocation(), - }, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_dns_managed_zone_enumerator.go b/pkg/remote/google/google_dns_managed_zone_enumerator.go deleted file mode 100644 index aa9953d7..00000000 --- a/pkg/remote/google/google_dns_managed_zone_enumerator.go +++ /dev/null @@ -1,59 +0,0 @@ -package google - -import ( - "strings" - - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleDNSManagedZoneEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleDNSManagedZoneEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleDNSManagedZoneEnumerator { - return &GoogleDNSManagedZoneEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleDNSManagedZoneEnumerator) SupportedType() resource.ResourceType { - return google.GoogleDNSManagedZoneResourceType -} - -func (e *GoogleDNSManagedZoneEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllDNSManagedZones() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - // We should have ID = "projects/cloudskiff-dev-elie/managedZones/example-zone" - // We have projects/cloudskiff-dev-elie/managedZones/2435093289230056557 - for _, res := range resources { - id := trimResourceName(res.Name) - splittedId := strings.Split(id, "/managedZones/") - if len(splittedId) != 2 { - logrus.WithField("id", res.Name).Warn("Cannot parse google_dns_managed_zone ID") - continue - } - id = strings.Join([]string{splittedId[0], "managedZones", res.DisplayName}, "/") - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_project_iam_member_enumerator.go b/pkg/remote/google/google_project_iam_member_enumerator.go deleted file mode 100644 index 44f40d9c..00000000 --- a/pkg/remote/google/google_project_iam_member_enumerator.go +++ /dev/null @@ -1,57 +0,0 @@ -package google - -import ( - "fmt" - - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleProjectIamMemberEnumerator struct { - repository repository.CloudResourceManagerRepository - factory resource.ResourceFactory -} - -func NewGoogleProjectIamMemberEnumerator(repo repository.CloudResourceManagerRepository, factory resource.ResourceFactory) *GoogleProjectIamMemberEnumerator { - return &GoogleProjectIamMemberEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleProjectIamMemberEnumerator) SupportedType() resource.ResourceType { - return google.GoogleProjectIamMemberResourceType -} - -func (e *GoogleProjectIamMemberEnumerator) Enumerate() ([]*resource.Resource, error) { - results := make([]*resource.Resource, 0) - - bindingsByProject, err := e.repository.ListProjectsBindings() - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for project, bindings := range bindingsByProject { - for roleName, members := range bindings { - for _, member := range members { - id := fmt.Sprintf("%s/%s/%s", project, roleName, member) - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{ - "id": id, - "project": project, - "role": roleName, - "member": member, - }, - ), - ) - } - } - } - - return results, err -} diff --git a/pkg/remote/google/google_sql_database_instance_enumerator.go b/pkg/remote/google/google_sql_database_instance_enumerator.go deleted file mode 100644 index 524b09c2..00000000 --- a/pkg/remote/google/google_sql_database_instance_enumerator.go +++ /dev/null @@ -1,53 +0,0 @@ -package google - -import ( - "github.com/sirupsen/logrus" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleSQLDatabaseInstanceEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleSQLDatabaseInstanceEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleSQLDatabaseInstanceEnumerator { - return &GoogleSQLDatabaseInstanceEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleSQLDatabaseInstanceEnumerator) SupportedType() resource.ResourceType { - return google.GoogleSQLDatabaseInstanceResourceType -} - -func (e *GoogleSQLDatabaseInstanceEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllSQLDatabaseInstances() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - name, exist := res.GetResource().GetData().GetFields()["name"] - if !exist || name.GetStringValue() == "" { - logrus.WithField("name", res.GetName()).Warn("Unable to retrieve resource name") - continue - } - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - name.GetStringValue(), - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_storage_bucket_enumerator.go b/pkg/remote/google/google_storage_bucket_enumerator.go deleted file mode 100644 index a2f1ee10..00000000 --- a/pkg/remote/google/google_storage_bucket_enumerator.go +++ /dev/null @@ -1,47 +0,0 @@ -package google - -import ( - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleStorageBucketEnumerator struct { - repository repository.AssetRepository - factory resource.ResourceFactory -} - -func NewGoogleStorageBucketEnumerator(repo repository.AssetRepository, factory resource.ResourceFactory) *GoogleStorageBucketEnumerator { - return &GoogleStorageBucketEnumerator{ - repository: repo, - factory: factory, - } -} - -func (e *GoogleStorageBucketEnumerator) SupportedType() resource.ResourceType { - return google.GoogleStorageBucketResourceType -} - -func (e *GoogleStorageBucketEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllBuckets() - - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, res := range resources { - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - res.DisplayName, - map[string]interface{}{}, - ), - ) - } - - return results, err -} diff --git a/pkg/remote/google/google_storage_bucket_iam_member_enumerator.go b/pkg/remote/google/google_storage_bucket_iam_member_enumerator.go deleted file mode 100644 index badf0d81..00000000 --- a/pkg/remote/google/google_storage_bucket_iam_member_enumerator.go +++ /dev/null @@ -1,64 +0,0 @@ -package google - -import ( - "fmt" - - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" -) - -type GoogleStorageBucketIamMemberEnumerator struct { - repository repository.AssetRepository - storageRepository repository.StorageRepository - factory resource.ResourceFactory -} - -func NewGoogleStorageBucketIamMemberEnumerator(repo repository.AssetRepository, storageRepo repository.StorageRepository, factory resource.ResourceFactory) *GoogleStorageBucketIamMemberEnumerator { - return &GoogleStorageBucketIamMemberEnumerator{ - repository: repo, - storageRepository: storageRepo, - factory: factory, - } -} - -func (e *GoogleStorageBucketIamMemberEnumerator) SupportedType() resource.ResourceType { - return google.GoogleStorageBucketIamMemberResourceType -} - -func (e *GoogleStorageBucketIamMemberEnumerator) Enumerate() ([]*resource.Resource, error) { - resources, err := e.repository.SearchAllBuckets() - if err != nil { - return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), google.GoogleStorageBucketResourceType) - } - - results := make([]*resource.Resource, 0, len(resources)) - - for _, bucket := range resources { - bindings, err := e.storageRepository.ListAllBindings(bucket.DisplayName) - if err != nil { - return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) - } - for roleName, members := range bindings { - for _, member := range members { - id := fmt.Sprintf("b/%s/%s/%s", bucket.DisplayName, roleName, member) - results = append( - results, - e.factory.CreateAbstractResource( - string(e.SupportedType()), - id, - map[string]interface{}{ - "id": id, - "bucket": fmt.Sprintf("b/%s", bucket.DisplayName), - "role": roleName, - "member": member, - }, - ), - ) - } - } - } - - return results, err -} diff --git a/pkg/remote/google/init.go b/pkg/remote/google/init.go deleted file mode 100644 index caaaa642..00000000 --- a/pkg/remote/google/init.go +++ /dev/null @@ -1,119 +0,0 @@ -package google - -import ( - "context" - - asset "cloud.google.com/go/asset/apiv1" - "cloud.google.com/go/storage" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - "google.golang.org/api/cloudresourcemanager/v1" -) - -func Init(version string, alerter *alerter.Alerter, - providerLibrary *terraform.ProviderLibrary, - remoteLibrary *common.RemoteLibrary, - progress output.Progress, - resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory, - configDir string) error { - - provider, err := NewGCPTerraformProvider(version, progress, configDir) - if err != nil { - return err - } - - err = provider.CheckCredentialsExist() - if err != nil { - return err - } - - err = provider.Init() - if err != nil { - return err - } - - repositoryCache := cache.New(100) - - ctx := context.Background() - assetClient, err := asset.NewClient(ctx) - if err != nil { - return err - } - - storageClient, err := storage.NewClient(ctx) - if err != nil { - return err - } - - crmService, err := cloudresourcemanager.NewService(ctx) - if err != nil { - return err - } - - assetRepository := repository.NewAssetRepository(assetClient, provider.GetConfig(), repositoryCache) - storageRepository := repository.NewStorageRepository(storageClient, repositoryCache) - iamRepository := repository.NewCloudResourceManagerRepository(crmService, provider.GetConfig(), repositoryCache) - - providerLibrary.AddProvider(terraform.GOOGLE, provider) - deserializer := resource.NewDeserializer(factory) - - remoteLibrary.AddEnumerator(NewGoogleStorageBucketEnumerator(assetRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleStorageBucketResourceType, common.NewGenericDetailsFetcher(google.GoogleStorageBucketResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleComputeFirewallEnumerator(assetRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleComputeFirewallResourceType, common.NewGenericDetailsFetcher(google.GoogleComputeFirewallResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleComputeRouterEnumerator(assetRepository, factory)) - - remoteLibrary.AddEnumerator(NewGoogleComputeInstanceEnumerator(assetRepository, factory)) - - remoteLibrary.AddEnumerator(NewGoogleProjectIamMemberEnumerator(iamRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleProjectIamMemberResourceType, common.NewGenericDetailsFetcher(google.GoogleProjectIamMemberResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleStorageBucketIamMemberEnumerator(assetRepository, storageRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleStorageBucketIamMemberResourceType, common.NewGenericDetailsFetcher(google.GoogleStorageBucketIamMemberResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleComputeNetworkEnumerator(assetRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleComputeNetworkResourceType, common.NewGenericDetailsFetcher(google.GoogleComputeNetworkResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleComputeSubnetworkEnumerator(assetRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleComputeSubnetworkResourceType, common.NewGenericDetailsFetcher(google.GoogleComputeSubnetworkResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleDNSManagedZoneEnumerator(assetRepository, factory)) - - remoteLibrary.AddEnumerator(NewGoogleComputeInstanceGroupEnumerator(assetRepository, factory)) - remoteLibrary.AddDetailsFetcher(google.GoogleComputeInstanceGroupResourceType, common.NewGenericDetailsFetcher(google.GoogleComputeInstanceGroupResourceType, provider, deserializer)) - - remoteLibrary.AddEnumerator(NewGoogleBigqueryDatasetEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleBigqueryTableEnumerator(assetRepository, factory)) - - remoteLibrary.AddEnumerator(NewGoogleComputeAddressEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeGlobalAddressEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleCloudFunctionsFunctionEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeDiskEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeImageEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleBigTableInstanceEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleBigtableTableEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleSQLDatabaseInstanceEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeHealthCheckEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleCloudRunServiceEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeNodeGroupEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeForwardingRuleEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeInstanceGroupManagerEnumerator(assetRepository, factory)) - remoteLibrary.AddEnumerator(NewGoogleComputeGlobalForwardingRuleEnumerator(assetRepository, factory)) - - err = resourceSchemaRepository.Init(terraform.GOOGLE, provider.Version(), provider.Schema()) - if err != nil { - return err - } - google.InitResourcesMetadata(resourceSchemaRepository) - - return nil -} diff --git a/pkg/remote/google/provider.go b/pkg/remote/google/provider.go deleted file mode 100644 index 35995fce..00000000 --- a/pkg/remote/google/provider.go +++ /dev/null @@ -1,77 +0,0 @@ -package google - -import ( - "context" - "errors" - "os" - - asset "cloud.google.com/go/asset/apiv1" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/google/config" - "github.com/snyk/driftctl/pkg/remote/terraform" - tf "github.com/snyk/driftctl/pkg/terraform" -) - -type GCPTerraformProvider struct { - *terraform.TerraformProvider - name string - version string -} - -func NewGCPTerraformProvider(version string, progress output.Progress, configDir string) (*GCPTerraformProvider, error) { - if version == "" { - version = "3.78.0" - } - p := &GCPTerraformProvider{ - version: version, - name: tf.GOOGLE, - } - installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ - Key: p.name, - Version: version, - ConfigDir: configDir, - }) - if err != nil { - return nil, err - } - tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{ - Name: p.name, - GetProviderConfig: func(alias string) interface{} { - return p.GetConfig() - }, - }, progress) - - if err != nil { - return nil, err - } - - p.TerraformProvider = tfProvider - - return p, err -} - -func (p *GCPTerraformProvider) Name() string { - return p.name -} - -func (p *GCPTerraformProvider) Version() string { - return p.version -} - -func (p *GCPTerraformProvider) GetConfig() config.GCPTerraformConfig { - return config.GCPTerraformConfig{ - Project: os.Getenv("CLOUDSDK_CORE_PROJECT"), - Region: os.Getenv("CLOUDSDK_COMPUTE_REGION"), - Zone: os.Getenv("CLOUDSDK_COMPUTE_ZONE"), - } -} - -func (p *GCPTerraformProvider) CheckCredentialsExist() error { - client, err := asset.NewClient(context.Background()) - if err != nil { - return errors.New("Please use a Service Account to authenticate on GCP.\n" + - "For more information: https://cloud.google.com/docs/authentication/production") - } - _ = client.Close() - return nil -} diff --git a/pkg/remote/google/repository/asset.go b/pkg/remote/google/repository/asset.go deleted file mode 100644 index 5e3474c3..00000000 --- a/pkg/remote/google/repository/asset.go +++ /dev/null @@ -1,282 +0,0 @@ -package repository - -import ( - "context" - "fmt" - - asset "cloud.google.com/go/asset/apiv1" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/google/config" - "google.golang.org/api/iterator" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" -) - -// https://cloud.google.com/asset-inventory/docs/supported-asset-types#supported_resource_types -const ( - storageBucketAssetType = "storage.googleapis.com/Bucket" - computeFirewallAssetType = "compute.googleapis.com/Firewall" - computeRouterAssetType = "compute.googleapis.com/Router" - computeInstanceAssetType = "compute.googleapis.com/Instance" - computeNetworkAssetType = "compute.googleapis.com/Network" - computeSubnetworkAssetType = "compute.googleapis.com/Subnetwork" - computeDiskAssetType = "compute.googleapis.com/Disk" - computeImageAssetType = "compute.googleapis.com/Image" - dnsManagedZoneAssetType = "dns.googleapis.com/ManagedZone" - computeInstanceGroupAssetType = "compute.googleapis.com/InstanceGroup" - bigqueryDatasetAssetType = "bigquery.googleapis.com/Dataset" - bigqueryTableAssetType = "bigquery.googleapis.com/Table" - computeAddressAssetType = "compute.googleapis.com/Address" - computeGlobalAddressAssetType = "compute.googleapis.com/GlobalAddress" - cloudFunctionsFunction = "cloudfunctions.googleapis.com/CloudFunction" - bigtableInstanceAssetType = "bigtableadmin.googleapis.com/Instance" - bigtableTableAssetType = "bigtableadmin.googleapis.com/Table" - sqlDatabaseInstanceAssetType = "sqladmin.googleapis.com/Instance" - healthCheckAssetType = "compute.googleapis.com/HealthCheck" - cloudRunServiceAssetType = "run.googleapis.com/Service" - nodeGroupAssetType = "compute.googleapis.com/NodeGroup" - computeForwardingRuleAssetType = "compute.googleapis.com/ForwardingRule" - instanceGroupManagerAssetType = "compute.googleapis.com/InstanceGroupManager" - computeGlobalForwardingRuleAssetType = "compute.googleapis.com/GlobalForwardingRule" -) - -type AssetRepository interface { - SearchAllBuckets() ([]*assetpb.ResourceSearchResult, error) - SearchAllFirewalls() ([]*assetpb.ResourceSearchResult, error) - SearchAllRouters() ([]*assetpb.ResourceSearchResult, error) - SearchAllInstances() ([]*assetpb.ResourceSearchResult, error) - SearchAllNetworks() ([]*assetpb.ResourceSearchResult, error) - SearchAllDisks() ([]*assetpb.ResourceSearchResult, error) - SearchAllImages() ([]*assetpb.ResourceSearchResult, error) - SearchAllDNSManagedZones() ([]*assetpb.ResourceSearchResult, error) - SearchAllInstanceGroups() ([]*assetpb.ResourceSearchResult, error) - SearchAllDatasets() ([]*assetpb.ResourceSearchResult, error) - SearchAllTables() ([]*assetpb.ResourceSearchResult, error) - SearchAllAddresses() ([]*assetpb.ResourceSearchResult, error) - SearchAllGlobalAddresses() ([]*assetpb.Asset, error) - SearchAllFunctions() ([]*assetpb.Asset, error) - SearchAllSubnetworks() ([]*assetpb.ResourceSearchResult, error) - SearchAllBigtableInstances() ([]*assetpb.Asset, error) - SearchAllBigtableTables() ([]*assetpb.Asset, error) - SearchAllSQLDatabaseInstances() ([]*assetpb.Asset, error) - SearchAllHealthChecks() ([]*assetpb.ResourceSearchResult, error) - SearchAllCloudRunServices() ([]*assetpb.ResourceSearchResult, error) - SearchAllNodeGroups() ([]*assetpb.Asset, error) - SearchAllForwardingRules() ([]*assetpb.Asset, error) - SearchAllInstanceGroupManagers() ([]*assetpb.Asset, error) - SearchAllGlobalForwardingRules() ([]*assetpb.Asset, error) -} - -type assetRepository struct { - client *asset.Client - config config.GCPTerraformConfig - cache cache.Cache -} - -func NewAssetRepository(client *asset.Client, config config.GCPTerraformConfig, c cache.Cache) *assetRepository { - return &assetRepository{ - client, - config, - c, - } -} - -func (s assetRepository) listAllResources(ty string) ([]*assetpb.Asset, error) { - req := &assetpb.ListAssetsRequest{ - Parent: fmt.Sprintf("projects/%s", s.config.Project), - ContentType: assetpb.ContentType_RESOURCE, - AssetTypes: []string{ - cloudFunctionsFunction, - bigtableInstanceAssetType, - bigtableTableAssetType, - sqlDatabaseInstanceAssetType, - computeGlobalAddressAssetType, - nodeGroupAssetType, - computeForwardingRuleAssetType, - instanceGroupManagerAssetType, - computeGlobalForwardingRuleAssetType, - }, - } - var results []*assetpb.Asset - - cacheKey := "listAllResources" - cachedResults := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if cachedResults != nil { - results = cachedResults.([]*assetpb.Asset) - } - - if results == nil { - it := s.client.ListAssets(context.Background(), req) - for { - resource, err := it.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, err - } - results = append(results, resource) - } - s.cache.Put(cacheKey, results) - } - - filteredResults := []*assetpb.Asset{} - for _, result := range results { - if result.AssetType == ty { - filteredResults = append(filteredResults, result) - } - } - - return filteredResults, nil -} - -func (s assetRepository) searchAllResources(ty string) ([]*assetpb.ResourceSearchResult, error) { - req := &assetpb.SearchAllResourcesRequest{ - Scope: fmt.Sprintf("projects/%s", s.config.Project), - AssetTypes: []string{ - storageBucketAssetType, - computeFirewallAssetType, - computeRouterAssetType, - computeInstanceAssetType, - computeNetworkAssetType, - computeSubnetworkAssetType, - dnsManagedZoneAssetType, - computeInstanceGroupAssetType, - bigqueryDatasetAssetType, - bigqueryTableAssetType, - computeAddressAssetType, - computeDiskAssetType, - computeImageAssetType, - healthCheckAssetType, - cloudRunServiceAssetType, - }, - } - var results []*assetpb.ResourceSearchResult - - cacheKey := "SearchAllResources" - cachedResults := s.cache.GetAndLock(cacheKey) - defer s.cache.Unlock(cacheKey) - if cachedResults != nil { - results = cachedResults.([]*assetpb.ResourceSearchResult) - } - - if results == nil { - it := s.client.SearchAllResources(context.Background(), req) - for { - resource, err := it.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, err - } - results = append(results, resource) - } - s.cache.Put(cacheKey, results) - } - - filteredResults := []*assetpb.ResourceSearchResult{} - for _, result := range results { - if result.AssetType == ty { - filteredResults = append(filteredResults, result) - } - } - - return filteredResults, nil -} - -func (s assetRepository) SearchAllBuckets() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(storageBucketAssetType) -} - -func (s assetRepository) SearchAllFirewalls() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeFirewallAssetType) -} - -func (s assetRepository) SearchAllRouters() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeRouterAssetType) -} - -func (s assetRepository) SearchAllInstances() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeInstanceAssetType) -} - -func (s assetRepository) SearchAllNetworks() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeNetworkAssetType) -} - -func (s assetRepository) SearchAllDNSManagedZones() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(dnsManagedZoneAssetType) -} - -func (s assetRepository) SearchAllInstanceGroups() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeInstanceGroupAssetType) -} - -func (s assetRepository) SearchAllDatasets() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(bigqueryDatasetAssetType) -} - -func (s assetRepository) SearchAllTables() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(bigqueryTableAssetType) -} - -func (s assetRepository) SearchAllAddresses() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeAddressAssetType) -} - -func (s assetRepository) SearchAllGlobalAddresses() ([]*assetpb.Asset, error) { - return s.listAllResources(computeGlobalAddressAssetType) -} - -func (s assetRepository) SearchAllFunctions() ([]*assetpb.Asset, error) { - return s.listAllResources(cloudFunctionsFunction) -} - -func (s assetRepository) SearchAllSubnetworks() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeSubnetworkAssetType) -} - -func (s assetRepository) SearchAllDisks() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeDiskAssetType) -} - -func (s assetRepository) SearchAllImages() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(computeImageAssetType) -} - -func (s assetRepository) SearchAllBigtableInstances() ([]*assetpb.Asset, error) { - return s.listAllResources(bigtableInstanceAssetType) -} - -func (s assetRepository) SearchAllBigtableTables() ([]*assetpb.Asset, error) { - return s.listAllResources(bigtableTableAssetType) -} - -func (s assetRepository) SearchAllSQLDatabaseInstances() ([]*assetpb.Asset, error) { - return s.listAllResources(sqlDatabaseInstanceAssetType) -} - -func (s assetRepository) SearchAllHealthChecks() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(healthCheckAssetType) -} - -func (s assetRepository) SearchAllCloudRunServices() ([]*assetpb.ResourceSearchResult, error) { - return s.searchAllResources(cloudRunServiceAssetType) -} - -func (s assetRepository) SearchAllNodeGroups() ([]*assetpb.Asset, error) { - return s.listAllResources(nodeGroupAssetType) -} - -func (s assetRepository) SearchAllForwardingRules() ([]*assetpb.Asset, error) { - return s.listAllResources(computeForwardingRuleAssetType) -} - -func (s assetRepository) SearchAllInstanceGroupManagers() ([]*assetpb.Asset, error) { - return s.listAllResources(instanceGroupManagerAssetType) -} - -func (s assetRepository) SearchAllGlobalForwardingRules() ([]*assetpb.Asset, error) { - return s.listAllResources(computeGlobalForwardingRuleAssetType) -} diff --git a/pkg/remote/google/repository/asset_test.go b/pkg/remote/google/repository/asset_test.go deleted file mode 100644 index 4db980a6..00000000 --- a/pkg/remote/google/repository/asset_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package repository - -import ( - "testing" - - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/google/config" - "github.com/snyk/driftctl/test/google" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" -) - -func Test_assetRepository_searchAllResources_CacheHit(t *testing.T) { - - expectedResults := []*assetpb.ResourceSearchResult{ - { - AssetType: "google_fake_type", - DisplayName: "driftctl-unittest-1", - }, - { - AssetType: "google_another_fake_type", - DisplayName: "driftctl-unittest-1", - }, - } - - c := &cache.MockCache{} - c.On("GetAndLock", "SearchAllResources").Return(expectedResults).Times(1) - c.On("Unlock", "SearchAllResources").Times(1) - repo := NewAssetRepository(nil, config.GCPTerraformConfig{Project: ""}, c) - - got, err := repo.searchAllResources("google_fake_type") - c.AssertExpectations(t) - assert.Nil(t, err) - assert.Len(t, got, 1) -} - -func Test_assetRepository_searchAllResources_CacheMiss(t *testing.T) { - - expectedResults := []*assetpb.ResourceSearchResult{ - { - AssetType: "google_fake_type", - DisplayName: "driftctl-unittest-1", - }, - { - AssetType: "google_another_fake_type", - DisplayName: "driftctl-unittest-1", - }, - } - assetClient, err := google.NewFakeAssetServer(expectedResults, nil) - if err != nil { - t.Fatal(err) - } - c := &cache.MockCache{} - c.On("GetAndLock", "SearchAllResources").Return(nil).Times(1) - c.On("Unlock", "SearchAllResources").Times(1) - c.On("Put", "SearchAllResources", mock.IsType([]*assetpb.ResourceSearchResult{})).Return(false).Times(1) - repo := NewAssetRepository(assetClient, config.GCPTerraformConfig{Project: ""}, c) - - got, err := repo.searchAllResources("google_fake_type") - c.AssertExpectations(t) - assert.Nil(t, err) - assert.Len(t, got, 1) -} diff --git a/pkg/remote/google/repository/cloudresourcemanager.go b/pkg/remote/google/repository/cloudresourcemanager.go deleted file mode 100644 index 1c7f4e3b..00000000 --- a/pkg/remote/google/repository/cloudresourcemanager.go +++ /dev/null @@ -1,50 +0,0 @@ -package repository - -import ( - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/google/config" - "google.golang.org/api/cloudresourcemanager/v1" -) - -type CloudResourceManagerRepository interface { - ListProjectsBindings() (map[string]map[string][]string, error) -} - -type cloudResourceManagerRepository struct { - service *cloudresourcemanager.Service - config config.GCPTerraformConfig - cache cache.Cache -} - -func NewCloudResourceManagerRepository(service *cloudresourcemanager.Service, config config.GCPTerraformConfig, cache cache.Cache) CloudResourceManagerRepository { - return &cloudResourceManagerRepository{ - service: service, - config: config, - cache: cache, - } -} - -func (s *cloudResourceManagerRepository) ListProjectsBindings() (map[string]map[string][]string, error) { - if cachedResults := s.cache.Get("ListProjectsBindings"); cachedResults != nil { - return cachedResults.(map[string]map[string][]string), nil - } - - request := new(cloudresourcemanager.GetIamPolicyRequest) - policy, err := s.service.Projects.GetIamPolicy(s.config.Project, request).Do() - if err != nil { - return nil, err - } - - bindings := make(map[string][]string) - - for _, binding := range policy.Bindings { - bindings[binding.Role] = binding.Members - } - - bindingsByProject := make(map[string]map[string][]string) - bindingsByProject[s.config.Project] = bindings - - s.cache.Put("ListProjectsBindings", bindingsByProject) - - return bindingsByProject, nil -} diff --git a/pkg/remote/google/repository/storage.go b/pkg/remote/google/repository/storage.go deleted file mode 100644 index 3679c131..00000000 --- a/pkg/remote/google/repository/storage.go +++ /dev/null @@ -1,52 +0,0 @@ -package repository - -import ( - "context" - "fmt" - "sync" - - "cloud.google.com/go/storage" - "github.com/snyk/driftctl/pkg/remote/cache" -) - -type StorageRepository interface { - ListAllBindings(bucketName string) (map[string][]string, error) -} - -type storageRepository struct { - client *storage.Client - cache cache.Cache - lock sync.Locker -} - -func NewStorageRepository(client *storage.Client, cache cache.Cache) *storageRepository { - return &storageRepository{ - client: client, - cache: cache, - lock: &sync.Mutex{}, - } -} - -func (s storageRepository) ListAllBindings(bucketName string) (map[string][]string, error) { - - s.lock.Lock() - defer s.lock.Unlock() - if cachedResults := s.cache.Get(fmt.Sprintf("%s-%s", "ListAllBindings", bucketName)); cachedResults != nil { - return cachedResults.(map[string][]string), nil - } - - bucket := s.client.Bucket(bucketName) - policy, err := bucket.IAM().Policy(context.Background()) - if err != nil { - return nil, err - } - bindings := make(map[string][]string) - for _, name := range policy.Roles() { - members := policy.Members(name) - bindings[string(name)] = members - } - - s.cache.Put("ListAllBindings", bindings) - - return bindings, nil -} diff --git a/pkg/remote/google_bigquery_scanner_test.go b/pkg/remote/google_bigquery_scanner_test.go deleted file mode 100644 index c37e4287..00000000 --- a/pkg/remote/google_bigquery_scanner_test.go +++ /dev/null @@ -1,231 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestGoogleBigqueryDataset(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no dataset", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples dataset", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/datasets/example_dataset", got[0].ResourceId()) - assert.Equal(t, "google_bigquery_dataset", got[0].ResourceType()) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "bigquery.googleapis.com/Dataset", - Name: "//bigquery.googleapis.com/projects/cloudskiff-dev-elie/datasets/example_dataset", - }, - }, - }, - { - test: "cannot list datasets", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_bigquery_dataset", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_bigquery_dataset", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleBigqueryDatasetEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleBigqueryTable(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no table", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples table", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/datasets/example_dataset/tables/bar", got[0].ResourceId()) - assert.Equal(t, "google_bigquery_table", got[0].ResourceType()) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "bigquery.googleapis.com/Table", - Name: "//bigquery.googleapis.com/projects/cloudskiff-dev-elie/datasets/example_dataset/tables/bar", - }, - }, - }, - { - test: "cannot list table", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_bigquery_table", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_bigquery_table", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleBigqueryTableEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_bigtable_scanner_test.go b/pkg/remote/google_bigtable_scanner_test.go deleted file mode 100644 index f75798cd..00000000 --- a/pkg/remote/google_bigtable_scanner_test.go +++ /dev/null @@ -1,291 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" -) - -func TestGoogleBigtableInstance(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no instance", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "one instance returned", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/instances/tf-instance", got[0].ResourceId()) - assert.Equal(t, "google_bigtable_instance", got[0].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "bigtableadmin.googleapis.com/Instance", - Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance", - Resource: &assetpb.Resource{ - Data: func() *structpb.Struct { - v, err := structpb.NewStruct(map[string]interface{}{ - "name": "projects/cloudskiff-dev-elie/instances/tf-instance", - }) - if err != nil { - t.Fatal(err) - } - return v - }(), - }, - }, - }, - }, - { - test: "one instance without resource data", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - response: []*assetpb.Asset{ - { - AssetType: "bigtableadmin.googleapis.com/Instance", - Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance", - }, - { - AssetType: "bigtableadmin.googleapis.com/Instance", - Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance-2", - Resource: &assetpb.Resource{ - Data: func() *structpb.Struct { - v, err := structpb.NewStruct(map[string]interface{}{}) - if err != nil { - t.Fatal(err) - } - return v - }(), - }, - }, - }, - }, - { - test: "cannot list instances", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_bigtable_instance", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_bigtable_instance", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleBigTableInstanceEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleBigtableTable(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no table", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "one resource returned", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", got[0].ResourceId()) - assert.Equal(t, "google_bigtable_table", got[0].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "bigtableadmin.googleapis.com/Table", - Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", - Resource: &assetpb.Resource{ - Data: func() *structpb.Struct { - v, err := structpb.NewStruct(map[string]interface{}{ - "name": "projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", - }) - if err != nil { - t.Fatal(err) - } - return v - }(), - }, - }, - }, - }, - { - test: "one resource without resource data", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - response: []*assetpb.Asset{ - { - AssetType: "bigtableadmin.googleapis.com/Table", - Name: "//bigtable.googleapis.com/projects/cloudskiff-dev-elie/instances/tf-instance/tables/tf-table", - }, - }, - }, - { - test: "cannot list cloud functions", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_bigtable_table", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_bigtable_table", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleBigtableTableEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_cloudfunctions_scanner_test.go b/pkg/remote/google_cloudfunctions_scanner_test.go deleted file mode 100644 index bec5fb24..00000000 --- a/pkg/remote/google_cloudfunctions_scanner_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" -) - -func TestGoogleCloudFunctionsFunction(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute instance", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "one cloud function returned", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", got[0].ResourceId()) - assert.Equal(t, "google_cloudfunctions_function", got[0].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "cloudfunctions.googleapis.com/CloudFunction", - Name: "//cloudfunctions.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", - Resource: &assetpb.Resource{ - Data: func() *structpb.Struct { - v, err := structpb.NewStruct(map[string]interface{}{ - "name": "projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", - }) - if err != nil { - t.Fatal(err) - } - return v - }(), - }, - }, - }, - }, - { - test: "one cloud function without resource data", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - response: []*assetpb.Asset{ - { - AssetType: "cloudfunctions.googleapis.com/CloudFunction", - Name: "//cloudfunctions.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/functions/function-test", - }, - }, - }, - { - test: "cannot list cloud functions", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_cloudfunctions_function", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_cloudfunctions_function", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleCloudFunctionsFunctionEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_cloudrun_scanner_test.go b/pkg/remote/google_cloudrun_scanner_test.go deleted file mode 100644 index 032752ec..00000000 --- a/pkg/remote/google_cloudrun_scanner_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestGoogleCloudRunService(t *testing.T) { - - cases := []struct { - test string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - assertExpected func(t *testing.T, got []*resource.Resource) - }{ - { - test: "no resource", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples resources", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "run.googleapis.com/Service", - Name: "invalid ID", // Should be ignored - }, - { - AssetType: "run.googleapis.com/Service", - DisplayName: "cloudrun-srv-1", - Name: "//run.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/services/cloudrun-srv-1", - Location: "us-central1", - }, - { - AssetType: "run.googleapis.com/Service", - DisplayName: "cloudrun-srv-2", - Name: "//run.googleapis.com/projects/cloudskiff-dev-elie/locations/us-central1/services/cloudrun-srv-2", - Location: "us-central1", - }, - }, - wantErr: nil, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - - assert.Equal(t, got[0].ResourceId(), "locations/us-central1/namespaces/cloudskiff-dev-elie/services/cloudrun-srv-1") - assert.Equal(t, got[0].ResourceType(), googleresource.GoogleCloudRunServiceResourceType) - - assert.Equal(t, got[1].ResourceId(), "locations/us-central1/namespaces/cloudskiff-dev-elie/services/cloudrun-srv-2") - assert.Equal(t, got[1].ResourceType(), googleresource.GoogleCloudRunServiceResourceType) - }, - }, - { - test: "should return access denied error", - wantErr: nil, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - googleresource.GoogleCloudRunServiceResourceType, - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - googleresource.GoogleCloudRunServiceResourceType, - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleCloudRunServiceEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_compute_scanner_test.go b/pkg/remote/google_compute_scanner_test.go deleted file mode 100644 index bd2fa0bc..00000000 --- a/pkg/remote/google_compute_scanner_test.go +++ /dev/null @@ -1,1763 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" -) - -func TestGoogleComputeFirewall(t *testing.T) { - - cases := []struct { - test string - dirName string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute firewall", - dirName: "google_compute_firewall_empty", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - }, - { - test: "multiples compute firewall", - dirName: "google_compute_firewall", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Firewall", - DisplayName: "test-firewall-0", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-0", - }, - { - AssetType: "compute.googleapis.com/Firewall", - DisplayName: "test-firewall-1", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-1", - }, - { - AssetType: "compute.googleapis.com/Firewall", - DisplayName: "test-firewall-2", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-2", - }, - }, - wantErr: nil, - }, - { - test: "cannot list compute firewall", - dirName: "google_compute_firewall_empty", - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_firewall", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_firewall", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleComputeFirewallResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err = realProvider.Init() - if err != nil { - tt.Fatal(err) - } - provider.ShouldUpdate() - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeFirewallEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resType, common.NewGenericDetailsFetcher(resType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} - -func TestGoogleComputeRouter(t *testing.T) { - - cases := []struct { - test string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - assertExpected func(t *testing.T, got []*resource.Resource) - }{ - { - test: "no compute router", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute routers", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Router", - DisplayName: "test-router-0", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-0", - }, - { - AssetType: "compute.googleapis.com/Router", - DisplayName: "test-router-1", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-1", - }, - { - AssetType: "compute.googleapis.com/Router", - DisplayName: "test-router-2", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-2", - }, - }, - wantErr: nil, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 3) - - assert.Equal(t, got[0].ResourceId(), "projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-0") - assert.Equal(t, got[0].ResourceType(), googleresource.GoogleComputeRouterResourceType) - - assert.Equal(t, got[1].ResourceId(), "projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-1") - assert.Equal(t, got[1].ResourceType(), googleresource.GoogleComputeRouterResourceType) - - assert.Equal(t, got[2].ResourceId(), "projects/cloudskiff-dev-raphael/regions/us-central1/routers/test-router-2") - assert.Equal(t, got[2].ResourceType(), googleresource.GoogleComputeRouterResourceType) - }, - }, - { - test: "should return access denied error", - wantErr: nil, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - googleresource.GoogleComputeRouterResourceType, - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - googleresource.GoogleComputeRouterResourceType, - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeRouterEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeInstance(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute instance", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute instances", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/zones/us-central1-a/instances/test", got[0].ResourceId()) - assert.Equal(t, "google_compute_instance", got[0].ResourceType()) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Instance", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/zones/us-central1-a/instances/test", - }, - }, - }, - { - test: "cannot list compute firewall", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_instance", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_instance", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeInstanceEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeNetwork(t *testing.T) { - - cases := []struct { - test string - dirName string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no network", - dirName: "google_compute_network_empty", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - }, - { - test: "multiple networks", - dirName: "google_compute_network", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Network", - DisplayName: "driftctl-unittest-1", - Name: "//compute.googleapis.com/projects/driftctl-qa-1/global/networks/driftctl-unittest-1", - }, - { - AssetType: "compute.googleapis.com/Network", - DisplayName: "driftctl-unittest-2", - Name: "//compute.googleapis.com/projects/driftctl-qa-1/global/networks/driftctl-unittest-2", - }, - { - AssetType: "compute.googleapis.com/Network", - DisplayName: "driftctl-unittest-3", - Name: "//compute.googleapis.com/projects/driftctl-qa-1/global/networks/driftctl-unittest-3", - }, - }, - wantErr: nil, - }, - { - test: "cannot list compute networks", - dirName: "google_compute_network_empty", - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_network", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_network", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleComputeNetworkResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err = realProvider.Init() - if err != nil { - tt.Fatal(err) - } - provider.ShouldUpdate() - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeNetworkEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resType, common.NewGenericDetailsFetcher(resType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} - -func TestGoogleComputeInstanceGroup(t *testing.T) { - - cases := []struct { - test string - dirName string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no instance group", - dirName: "google_compute_instance_group_empty", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - }, - { - test: "multiple instance groups", - dirName: "google_compute_instance_group", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/InstanceGroup", - DisplayName: "driftctl-test-1", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-1", - Project: "cloudskiff-dev-raphael", - Location: "us-central1-a", - }, - { - AssetType: "compute.googleapis.com/InstanceGroup", - DisplayName: "driftctl-test-2", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-2", - Project: "cloudskiff-dev-raphael", - Location: "us-central1-a", - }, - }, - wantErr: nil, - }, - { - test: "cannot list instance groups", - dirName: "google_compute_instance_group_empty", - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_instance_group", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_instance_group", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleComputeInstanceGroupResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err = realProvider.Init() - if err != nil { - tt.Fatal(err) - } - provider.ShouldUpdate() - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeInstanceGroupEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(googleresource.GoogleComputeInstanceGroupResourceType, common.NewGenericDetailsFetcher(googleresource.GoogleComputeInstanceGroupResourceType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} - -func TestGoogleComputeAddress(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute address", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute address", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address", got[0].ResourceId()) - assert.Equal(t, "google_compute_address", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address-2", got[1].ResourceId()) - assert.Equal(t, "google_compute_address", got[1].ResourceType()) - assert.Equal(t, "1.2.3.4", *got[1].Attributes().GetString("address")) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Address", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address", - }, - { - AssetType: "compute.googleapis.com/Address", - Location: "global", // Global addresses should be ignored - }, - { - AssetType: "compute.googleapis.com/Address", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/regions/us-central1/addresses/my-address-2", - AdditionalAttributes: func() *structpb.Struct { - str, _ := structpb.NewStruct(map[string]interface{}{ - "address": "1.2.3.4", - }) - return str - }(), - }, - }, - }, - { - test: "cannot list compute address", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_address", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_address", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeAddressEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeGlobalAddress(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no resource", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "one resource returned", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", got[0].ResourceId()) - assert.Equal(t, "google_compute_global_address", got[0].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "compute.googleapis.com/GlobalAddress", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", - Resource: &assetpb.Resource{ - Data: func() *structpb.Struct { - v, err := structpb.NewStruct(map[string]interface{}{ - "name": "projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", - }) - if err != nil { - t.Fatal(err) - } - return v - }(), - }, - }, - }, - }, - { - test: "one resource without resource data", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - response: []*assetpb.Asset{ - { - AssetType: "compute.googleapis.com/GlobalAddress", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/addresses/global-appserver-ip", - }, - }, - }, - { - test: "cannot list cloud functions", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_global_address", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_global_address", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeGlobalAddressEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeSubnetwork(t *testing.T) { - - cases := []struct { - test string - dirName string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no subnetwork", - dirName: "google_compute_subnetwork_empty", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - }, - { - test: "multiple subnetworks", - dirName: "google_compute_subnetwork_multiple", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Subnetwork", - DisplayName: "driftctl-unittest-1", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-1", - }, - { - AssetType: "compute.googleapis.com/Subnetwork", - DisplayName: "driftctl-unittest-2", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-2", - }, - { - AssetType: "compute.googleapis.com/Subnetwork", - DisplayName: "driftctl-unittest-3", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-3", - }, - }, - wantErr: nil, - }, - { - test: "cannot list compute subnetworks", - dirName: "google_compute_subnetwork_empty", - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_subnetwork", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_subnetwork", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleComputeSubnetworkResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - err = realProvider.Init() - if err != nil { - tt.Fatal(err) - } - provider.ShouldUpdate() - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeSubnetworkEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resType, common.NewGenericDetailsFetcher(resType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} - -func TestGoogleComputeDisk(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute disk", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute disk", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk", got[0].ResourceId()) - assert.Equal(t, "google_compute_disk", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk-2", got[1].ResourceId()) - assert.Equal(t, "google_compute_disk", got[1].ResourceType()) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Disk", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk", - }, - { - AssetType: "compute.googleapis.com/Disk", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/zones/us-central1-a/disks/test-disk-2", - }, - }, - }, - { - test: "cannot list compute disk", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_disk", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_disk", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeDiskEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeImage(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute image", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples images", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-elie/global/images/example-image", got[0].ResourceId()) - assert.Equal(t, "google_compute_image", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-elie/global/images/example-image-2", got[1].ResourceId()) - assert.Equal(t, "google_compute_image", got[1].ResourceType()) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/Image", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/images/example-image", - }, - { - AssetType: "compute.googleapis.com/Image", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-elie/global/images/example-image-2", - }, - }, - }, - { - test: "cannot list images", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_image", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_image", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeImageEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeHealthCheck(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute health check", - response: []*assetpb.ResourceSearchResult{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute health checks", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-1", got[0].ResourceId()) - assert.Equal(t, "google_compute_health_check", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-2", got[1].ResourceId()) - assert.Equal(t, "google_compute_health_check", got[1].ResourceType()) - }, - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "compute.googleapis.com/HealthCheck", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-1", - }, - { - AssetType: "compute.googleapis.com/HealthCheck", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/global/healthChecks/test-health-check-2", - }, - }, - }, - { - test: "cannot list compute health checks", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_health_check", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_health_check", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository(terraform.GOOGLE, providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeHealthCheckEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeNodeGroup(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute node group", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute node group", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/soletenant-group", got[0].ResourceId()) - assert.Equal(t, "google_compute_node_group", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/simple-group", got[1].ResourceId()) - assert.Equal(t, "google_compute_node_group", got[1].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "compute.googleapis.com/NodeGroup", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/soletenant-group", - }, - { - AssetType: "compute.googleapis.com/NodeGroup", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-martin/zones/us-central1-f/nodeGroups/simple-group", - }, - }, - }, - { - test: "cannot list compute node group", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_node_group", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_node_group", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository(terraform.GOOGLE, providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeNodeGroupEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeForwardingRule(t *testing.T) { - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute forwarding rules", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple compute forwarding rules", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/foo", got[0].ResourceId()) - assert.Equal(t, "google_compute_forwarding_rule", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/bar", got[1].ResourceId()) - assert.Equal(t, "google_compute_forwarding_rule", got[1].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "compute.googleapis.com/ForwardingRule", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/foo", - }, - { - AssetType: "compute.googleapis.com/ForwardingRule", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-william/regions/us-east1/forwardingRules/bar", - }, - }, - }, - { - test: "cannot list compute forwarding rules", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_forwarding_rule", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_forwarding_rule", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository(terraform.GOOGLE, providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeForwardingRuleEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeInstanceGroupManager(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute instance group manager", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples compute instance group managers", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-abc", got[0].ResourceId()) - assert.Equal(t, "google_compute_instance_group_manager", got[0].ResourceType()) - - assert.Equal(t, "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-def", got[1].ResourceId()) - assert.Equal(t, "google_compute_instance_group_manager", got[1].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "compute.googleapis.com/InstanceGroupManager", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-abc", - }, - { - AssetType: "compute.googleapis.com/InstanceGroupManager", - Name: "//compute.googleapis.com/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroupManagers/appserver-def", - }, - }, - }, - { - test: "cannot list compute instance group managers", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_instance_group_manager", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_instance_group_manager", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository(terraform.GOOGLE, providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeInstanceGroupManagerEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} - -func TestGoogleComputeGlobalForwardingRule(t *testing.T) { - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no compute global forwarding rules", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiple compute global forwarding rules", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 2) - assert.Equal(t, "//projects/driftctl-qa-1/global/forwardingRules/global-rule-foo", got[0].ResourceId()) - assert.Equal(t, "google_compute_global_forwarding_rule", got[0].ResourceType()) - - assert.Equal(t, "//projects/driftctl-qa-1/global/forwardingRules/global-rule-bar", got[1].ResourceId()) - assert.Equal(t, "google_compute_global_forwarding_rule", got[1].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "compute.googleapis.com/GlobalForwardingRule", - Name: "//projects/driftctl-qa-1/global/forwardingRules/global-rule-foo", - }, - { - AssetType: "compute.googleapis.com/GlobalForwardingRule", - Name: "//projects/driftctl-qa-1/global/forwardingRules/global-rule-bar", - }, - }, - }, - { - test: "cannot list compute global forwarding rules", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_compute_global_forwarding_rule", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_compute_global_forwarding_rule", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository(terraform.GOOGLE, providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleComputeGlobalForwardingRuleEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_network_scanner_test.go b/pkg/remote/google_network_scanner_test.go deleted file mode 100644 index 08cf9332..00000000 --- a/pkg/remote/google_network_scanner_test.go +++ /dev/null @@ -1,154 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestGoogleDNSNanagedZone(t *testing.T) { - - cases := []struct { - test string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - assertExpected func(t *testing.T, got []*resource.Resource) - }{ - { - test: "no managed zone", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "multiples managed zones", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "dns.googleapis.com/ManagedZone", - Name: "invalid ID", // Should be ignored - }, - { - AssetType: "dns.googleapis.com/ManagedZone", - DisplayName: "test-zone-0", - Name: "//dns.googleapis.com/projects/cloudskiff-dev-raphael/managedZones/123456789", - }, - { - AssetType: "dns.googleapis.com/ManagedZone", - DisplayName: "test-zone-1", - Name: "//dns.googleapis.com/projects/cloudskiff-dev-raphael/managedZones/123456789", - }, - { - AssetType: "dns.googleapis.com/ManagedZone", - DisplayName: "test-zone-2", - Name: "//dns.googleapis.com/projects/cloudskiff-dev-raphael/managedZones/123456789", - }, - }, - wantErr: nil, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 3) - - assert.Equal(t, got[0].ResourceId(), "projects/cloudskiff-dev-raphael/managedZones/test-zone-0") - assert.Equal(t, got[0].ResourceType(), googleresource.GoogleDNSManagedZoneResourceType) - - assert.Equal(t, got[1].ResourceId(), "projects/cloudskiff-dev-raphael/managedZones/test-zone-1") - assert.Equal(t, got[1].ResourceType(), googleresource.GoogleDNSManagedZoneResourceType) - - assert.Equal(t, got[2].ResourceId(), "projects/cloudskiff-dev-raphael/managedZones/test-zone-2") - assert.Equal(t, got[2].ResourceType(), googleresource.GoogleDNSManagedZoneResourceType) - }, - }, - { - test: "should return access denied error", - wantErr: nil, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - googleresource.GoogleDNSManagedZoneResourceType, - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - googleresource.GoogleDNSManagedZoneResourceType, - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleDNSManagedZoneEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_project_scanner_test.go b/pkg/remote/google_project_scanner_test.go deleted file mode 100644 index abd65f4f..00000000 --- a/pkg/remote/google_project_scanner_test.go +++ /dev/null @@ -1,138 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestGoogleProjectIAMMember(t *testing.T) { - - cases := []struct { - test string - dirName string - repositoryMock func(repository *repository.MockCloudResourceManagerRepository) - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no bindings", - dirName: "google_project_member_empty", - repositoryMock: func(repository *repository.MockCloudResourceManagerRepository) { - repository.On("ListProjectsBindings").Return(map[string]map[string][]string{}, nil) - }, - wantErr: nil, - }, - { - test: "Cannot list bindings", - dirName: "google_project_member_listing_error", - repositoryMock: func(repository *repository.MockCloudResourceManagerRepository) { - repository.On("ListProjectsBindings").Return( - map[string]map[string][]string{}, - errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have project.getIamPolicy access., forbidden")) - }, - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_project_iam_member", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have project.getIamPolicy access., forbidden"), - "google_project_iam_member", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - { - test: "multiples storage buckets, multiple bindings", - dirName: "google_project_member_listing_multiple", - repositoryMock: func(repository *repository.MockCloudResourceManagerRepository) { - repository.On("ListProjectsBindings").Return(map[string]map[string][]string{ - "": { - "roles/editor": { - "user:martin.guibert@cloudskiff.com", - "serviceAccount:drifctl-admin@cloudskiff-dev-martin.iam.gserviceaccount.com", - }, - "roles/storage.admin": {"user:martin.guibert@cloudskiff.com"}, - "roles/viewer": {"serviceAccount:driftctl@cloudskiff-dev-martin.iam.gserviceaccount.com"}, - "roles/cloudasset.viewer": {"serviceAccount:driftctl@cloudskiff-dev-martin.iam.gserviceaccount.com"}, - "roles/iam.securityReviewer": {"serviceAccount:driftctl@cloudskiff-dev-martin.iam.gserviceaccount.com"}, - }, - }, nil) - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleProjectIamMemberResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - managerRepository := &repository.MockCloudResourceManagerRepository{} - if c.repositoryMock != nil { - c.repositoryMock(managerRepository) - } - - remoteLibrary.AddEnumerator(google.NewGoogleProjectIamMemberEnumerator(managerRepository, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} diff --git a/pkg/remote/google_sql_scanner_test.go b/pkg/remote/google_sql_scanner_test.go deleted file mode 100644 index 65037442..00000000 --- a/pkg/remote/google_sql_scanner_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" -) - -func TestGoogleSQLDatabaseInstance(t *testing.T) { - - cases := []struct { - test string - assertExpected func(t *testing.T, got []*resource.Resource) - response []*assetpb.Asset - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no instance", - response: []*assetpb.Asset{}, - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - }, - { - test: "one resource returned", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 1) - assert.Equal(t, "instance-test", got[0].ResourceId()) - assert.Equal(t, "google_sql_database_instance", got[0].ResourceType()) - }, - response: []*assetpb.Asset{ - { - AssetType: "sqladmin.googleapis.com/Instance", - Resource: &assetpb.Resource{ - Data: func() *structpb.Struct { - v, err := structpb.NewStruct(map[string]interface{}{ - "name": "instance-test", - }) - if err != nil { - t.Fatal(err) - } - return v - }(), - }, - }, - }, - }, - { - test: "one resource without resource data", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - response: []*assetpb.Asset{ - { - AssetType: "sqladmin.googleapis.com/Instance", - }, - }, - }, - { - test: "cannot list resources", - assertExpected: func(t *testing.T, got []*resource.Resource) { - assert.Len(t, got, 0) - }, - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_sql_database_instance", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_sql_database_instance", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - }, - } - - providerVersion := "3.78.0" - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - scanOptions := ScannerOptions{} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - assetClient, err := testgoogle.NewFakeAssertServerWithList(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleSQLDatabaseInstanceEnumerator(repo, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - if c.assertExpected != nil { - c.assertExpected(t, got) - } - }) - } -} diff --git a/pkg/remote/google_storage_scanner_test.go b/pkg/remote/google_storage_scanner_test.go deleted file mode 100644 index cb718016..00000000 --- a/pkg/remote/google_storage_scanner_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package remote - -import ( - "context" - "testing" - - asset "cloud.google.com/go/asset/apiv1" - "cloud.google.com/go/storage" - "github.com/pkg/errors" - "github.com/snyk/driftctl/mocks" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/cache" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/remote/google/repository" - "github.com/snyk/driftctl/pkg/resource" - googleresource "github.com/snyk/driftctl/pkg/resource/google" - "github.com/snyk/driftctl/pkg/terraform" - "github.com/snyk/driftctl/test" - "github.com/snyk/driftctl/test/goldenfile" - testgoogle "github.com/snyk/driftctl/test/google" - testresource "github.com/snyk/driftctl/test/resource" - terraform2 "github.com/snyk/driftctl/test/terraform" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - assetpb "google.golang.org/genproto/googleapis/cloud/asset/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestGoogleStorageBucket(t *testing.T) { - - cases := []struct { - test string - dirName string - response []*assetpb.ResourceSearchResult - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no storage buckets", - dirName: "google_storage_bucket_empty", - response: []*assetpb.ResourceSearchResult{}, - wantErr: nil, - }, - { - test: "multiples storage buckets", - dirName: "google_storage_bucket", - response: []*assetpb.ResourceSearchResult{ - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "driftctl-unittest-1", - }, - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "driftctl-unittest-2", - }, - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "driftctl-unittest-3", - }, - }, - wantErr: nil, - }, - { - test: "cannot list storage buckets", - dirName: "google_storage_bucket_empty", - responseErr: status.Error(codes.PermissionDenied, "The caller does not have permission"), - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_storage_bucket", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - status.Error(codes.PermissionDenied, "The caller does not have permission"), - "google_storage_bucket", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleStorageBucketResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - var assetClient *asset.Client - if !shouldUpdate { - var err error - assetClient, err = testgoogle.NewFakeAssetServer(c.response, c.responseErr) - if err != nil { - tt.Fatal(err) - } - } - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - // Replace mock by real resources if we are in update mode - if shouldUpdate { - ctx := context.Background() - assetClient, err = asset.NewClient(ctx) - if err != nil { - tt.Fatal(err) - } - err = realProvider.Init() - if err != nil { - tt.Fatal(err) - } - provider.ShouldUpdate() - } - - repo := repository.NewAssetRepository(assetClient, realProvider.GetConfig(), cache.New(0)) - - remoteLibrary.AddEnumerator(google.NewGoogleStorageBucketEnumerator(repo, factory)) - remoteLibrary.AddDetailsFetcher(resType, common.NewGenericDetailsFetcher(resType, provider, deserializer)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, err, c.wantErr) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} - -func TestGoogleStorageBucketIAMMember(t *testing.T) { - - cases := []struct { - test string - dirName string - assetRepositoryMock func(assetRepository *repository.MockAssetRepository) - storageRepositoryMock func(storageRepository *repository.MockStorageRepository) - responseErr error - setupAlerterMock func(alerter *mocks.AlerterInterface) - wantErr error - }{ - { - test: "no storage buckets", - dirName: "google_storage_bucket_member_empty", - assetRepositoryMock: func(assetRepository *repository.MockAssetRepository) { - assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{}, nil) - }, - wantErr: nil, - }, - { - test: "multiples storage buckets, no bindings", - dirName: "google_storage_bucket_member_empty", - assetRepositoryMock: func(assetRepository *repository.MockAssetRepository) { - assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{ - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "dctlgstoragebucketiambinding-1", - }, - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "dctlgstoragebucketiambinding-2", - }, - }, nil) - }, - storageRepositoryMock: func(storageRepository *repository.MockStorageRepository) { - storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-1").Return(map[string][]string{}, nil) - storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-2").Return(map[string][]string{}, nil) - }, - wantErr: nil, - }, - { - test: "Cannot list bindings", - dirName: "google_storage_bucket_member_listing_error", - assetRepositoryMock: func(assetRepository *repository.MockAssetRepository) { - assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{ - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "dctlgstoragebucketiambinding-1", - }, - }, nil) - }, - storageRepositoryMock: func(storageRepository *repository.MockStorageRepository) { - storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-1").Return( - map[string][]string{}, - errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have storage.buckets.getIamPolicy access to the Google Cloud Storage bucket., forbidden")) - }, - setupAlerterMock: func(alerter *mocks.AlerterInterface) { - alerter.On( - "SendAlert", - "google_storage_bucket_iam_member", - alerts.NewRemoteAccessDeniedAlert( - common.RemoteGoogleTerraform, - remoteerr.NewResourceListingError( - errors.New("googleapi: Error 403: driftctl-acc-circle@driftctl-qa-1.iam.gserviceaccount.com does not have storage.buckets.getIamPolicy access to the Google Cloud Storage bucket., forbidden"), - "google_storage_bucket_iam_member", - ), - alerts.EnumerationPhase, - ), - ).Once() - }, - wantErr: nil, - }, - { - test: "multiples storage buckets, multiple bindings", - dirName: "google_storage_bucket_member_listing_multiple", - assetRepositoryMock: func(assetRepository *repository.MockAssetRepository) { - assetRepository.On("SearchAllBuckets").Return([]*assetpb.ResourceSearchResult{ - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "dctlgstoragebucketiambinding-1", - }, - { - AssetType: "storage.googleapis.com/Bucket", - DisplayName: "dctlgstoragebucketiambinding-2", - }, - }, nil) - }, - storageRepositoryMock: func(storageRepository *repository.MockStorageRepository) { - storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-1").Return(map[string][]string{ - "roles/storage.admin": {"user:elie.charra@cloudskiff.com"}, - "roles/storage.objectViewer": {"user:william.beuil@cloudskiff.com"}, - }, nil) - - storageRepository.On("ListAllBindings", "dctlgstoragebucketiambinding-2").Return(map[string][]string{ - "roles/storage.admin": {"user:william.beuil@cloudskiff.com"}, - "roles/storage.objectViewer": {"user:elie.charra@cloudskiff.com"}, - }, nil) - }, - wantErr: nil, - }, - } - - providerVersion := "3.78.0" - resType := resource.ResourceType(googleresource.GoogleStorageBucketIamMemberResourceType) - schemaRepository := testresource.InitFakeSchemaRepository("google", providerVersion) - googleresource.InitResourcesMetadata(schemaRepository) - factory := terraform.NewTerraformResourceFactory(schemaRepository) - deserializer := resource.NewDeserializer(factory) - - for _, c := range cases { - t.Run(c.test, func(tt *testing.T) { - repositoryCache := cache.New(100) - - shouldUpdate := c.dirName == *goldenfile.Update - - scanOptions := ScannerOptions{Deep: true} - providerLibrary := terraform.NewProviderLibrary() - remoteLibrary := common.NewRemoteLibrary() - - // Initialize mocks - alerter := &mocks.AlerterInterface{} - if c.setupAlerterMock != nil { - c.setupAlerterMock(alerter) - } - - storageRepo := &repository.MockStorageRepository{} - if c.storageRepositoryMock != nil { - c.storageRepositoryMock(storageRepo) - } - var storageRepository repository.StorageRepository = storageRepo - if shouldUpdate { - storageClient, err := storage.NewClient(context.Background()) - if err != nil { - panic(err) - } - storageRepository = repository.NewStorageRepository(storageClient, repositoryCache) - } - - assetRepo := &repository.MockAssetRepository{} - if c.assetRepositoryMock != nil { - c.assetRepositoryMock(assetRepo) - } - var assetRepository repository.AssetRepository = assetRepo - - realProvider, err := terraform2.InitTestGoogleProvider(providerLibrary, providerVersion) - if err != nil { - tt.Fatal(err) - } - provider := terraform2.NewFakeTerraformProvider(realProvider) - provider.WithResponse(c.dirName) - - remoteLibrary.AddEnumerator(google.NewGoogleStorageBucketIamMemberEnumerator(assetRepository, storageRepository, factory)) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", mock.Anything).Return(false) - - s := NewScanner(remoteLibrary, alerter, scanOptions, testFilter) - got, err := s.Resources() - assert.Equal(tt, c.wantErr, err) - if err != nil { - return - } - alerter.AssertExpectations(tt) - testFilter.AssertExpectations(tt) - test.TestAgainstGoldenFile(got, resType.String(), c.dirName, provider, deserializer, shouldUpdate, tt) - }) - } -} diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go deleted file mode 100644 index 30f4b112..00000000 --- a/pkg/remote/remote.go +++ /dev/null @@ -1,56 +0,0 @@ -package remote - -import ( - "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/remote/github" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" -) - -var supportedRemotes = []string{ - common.RemoteAWSTerraform, - common.RemoteGithubTerraform, - common.RemoteGoogleTerraform, - common.RemoteAzureTerraform, -} - -func IsSupported(remote string) bool { - for _, r := range supportedRemotes { - if r == remote { - return true - } - } - return false -} - -func Activate(remote, version string, alerter *alerter.Alerter, - providerLibrary *terraform.ProviderLibrary, - remoteLibrary *common.RemoteLibrary, - progress output.Progress, - resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory, - configDir string) error { - switch remote { - case common.RemoteAWSTerraform: - return aws.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) - case common.RemoteGithubTerraform: - return github.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) - case common.RemoteGoogleTerraform: - return google.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) - case common.RemoteAzureTerraform: - return azurerm.Init(version, alerter, providerLibrary, remoteLibrary, progress, resourceSchemaRepository, factory, configDir) - - default: - return errors.Errorf("unsupported remote '%s'", remote) - } -} - -func GetSupportedRemotes() []string { - return supportedRemotes -} diff --git a/pkg/remote/resource_enumeration_error_handler.go b/pkg/remote/resource_enumeration_error_handler.go deleted file mode 100644 index 25e16ba8..00000000 --- a/pkg/remote/resource_enumeration_error_handler.go +++ /dev/null @@ -1,115 +0,0 @@ -package remote - -import ( - "strings" - - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerror "github.com/snyk/driftctl/pkg/remote/error" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func HandleResourceEnumerationError(err error, alerter alerter.AlerterInterface) error { - listError, ok := err.(*remoteerror.ResourceScanningError) - if !ok { - return err - } - - rootCause := listError.RootCause() - - // We cannot use the status.FromError() method because AWS errors are not well-formed. - // Indeed, they compose the error interface without implementing the Error() method and thus triggering a nil panic - // when returning an unknown error from status.FromError() - // As a workaround we duplicated the logic from status.FromError here - if _, ok := rootCause.(interface{ GRPCStatus() *status.Status }); ok { - return handleGoogleEnumerationError(alerter, listError, status.Convert(rootCause)) - } - - // at least for storage api google sdk does not return grpc error so we parse the error message. - if shouldHandleGoogleForbiddenError(listError) { - alerts.SendEnumerationAlert(common.RemoteGoogleTerraform, alerter, listError) - return nil - } - - reqerr, ok := rootCause.(awserr.RequestFailure) - if ok { - return handleAWSError(alerter, listError, reqerr) - } - - // This handles access denied errors like the following: - // aws_s3_bucket_policy: AccessDenied: Error listing bucket policy - if strings.Contains(rootCause.Error(), "AccessDenied") { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, alerter, listError) - return nil - } - - if strings.HasPrefix( - rootCause.Error(), - "Your token has not been granted the required scopes to execute this query.", - ) { - alerts.SendEnumerationAlert(common.RemoteGithubTerraform, alerter, listError) - return nil - } - - return err -} - -func HandleResourceDetailsFetchingError(err error, alerter alerter.AlerterInterface) error { - listError, ok := err.(*remoteerror.ResourceScanningError) - if !ok { - return err - } - - rootCause := listError.RootCause() - - if shouldHandleGoogleForbiddenError(listError) { - alerts.SendDetailsFetchingAlert(common.RemoteGoogleTerraform, alerter, listError) - return nil - } - - // This handles access denied errors like the following: - // iam_role_policy: error reading IAM Role Policy (): AccessDenied: User: ... - if strings.HasPrefix(rootCause.Error(), "AccessDeniedException") || - strings.Contains(rootCause.Error(), "AccessDenied") || - strings.Contains(rootCause.Error(), "AuthorizationError") { - alerts.SendDetailsFetchingAlert(common.RemoteAWSTerraform, alerter, listError) - return nil - } - - return err -} - -func handleAWSError(alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError, reqerr awserr.RequestFailure) error { - if reqerr.StatusCode() == 403 || (reqerr.StatusCode() == 400 && strings.Contains(reqerr.Code(), "AccessDenied")) { - alerts.SendEnumerationAlert(common.RemoteAWSTerraform, alerter, listError) - return nil - } - - return reqerr -} - -func handleGoogleEnumerationError(alerter alerter.AlerterInterface, err *remoteerror.ResourceScanningError, st *status.Status) error { - if st.Code() == codes.PermissionDenied { - alerts.SendEnumerationAlert(common.RemoteGoogleTerraform, alerter, err) - return nil - } - return err -} - -func shouldHandleGoogleForbiddenError(err *remoteerror.ResourceScanningError) bool { - errMsg := err.RootCause().Error() - - // Check if this is a Google related error - if !strings.Contains(errMsg, "googleapi") { - return false - } - - if strings.Contains(errMsg, "Error 403") { - return true - } - - return false -} diff --git a/pkg/remote/resource_enumeration_error_handler_test.go b/pkg/remote/resource_enumeration_error_handler_test.go deleted file mode 100644 index 4ce74749..00000000 --- a/pkg/remote/resource_enumeration_error_handler_test.go +++ /dev/null @@ -1,375 +0,0 @@ -package remote - -import ( - "errors" - "testing" - - "github.com/snyk/driftctl/pkg/remote/alerts" - "github.com/snyk/driftctl/pkg/remote/common" - remoteerr "github.com/snyk/driftctl/pkg/remote/error" - resourcegithub "github.com/snyk/driftctl/pkg/resource/github" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/stretchr/testify/assert" - - "github.com/aws/aws-sdk-go/aws/awserr" - resourceaws "github.com/snyk/driftctl/pkg/resource/aws" - - "github.com/snyk/driftctl/pkg/alerter" -) - -func TestHandleAwsEnumerationErrors(t *testing.T) { - - tests := []struct { - name string - err error - wantAlerts alerter.Alerts - wantErr bool - }{ - { - name: "Handled error 403", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("", "", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), - wantAlerts: alerter.Alerts{"aws_vpc": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("", "", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.EnumerationPhase)}}, - wantErr: false, - }, - { - name: "Handled error AccessDenied", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, ""), resourceaws.AwsDynamodbTableResourceType), - wantAlerts: alerter.Alerts{"aws_dynamodb_table": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, ""), "aws_dynamodb_table", "aws_dynamodb_table"), alerts.EnumerationPhase)}}, - wantErr: false, - }, - { - name: "Not Handled error code", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("", "", errors.New("")), 404, ""), resourceaws.AwsVpcResourceType), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - { - name: "Not Handled error type", - err: errors.New("error"), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - { - name: "Not Handled root error type", - err: remoteerr.NewResourceListingError(errors.New("error"), resourceaws.AwsVpcResourceType), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - { - name: "Handle AccessDenied error", - err: remoteerr.NewResourceListingError(errors.New("an error occured: AccessDenied: 403"), resourceaws.AwsVpcResourceType), - wantAlerts: alerter.Alerts{"aws_vpc": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("an error occured: AccessDenied: 403"), "aws_vpc", "aws_vpc"), alerts.EnumerationPhase)}}, - wantErr: false, - }, - { - name: "Access denied error on a single resource", - err: remoteerr.NewResourceScanningError(errors.New("Error: AccessDenied: 403 ..."), resourceaws.AwsS3BucketResourceType, "my-bucket"), - wantAlerts: alerter.Alerts{"aws_s3_bucket.my-bucket": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Error: AccessDenied: 403 ..."), "aws_s3_bucket.my-bucket", "aws_s3_bucket"), alerts.EnumerationPhase)}}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - alertr := alerter.NewAlerter() - gotErr := HandleResourceEnumerationError(tt.err, alertr) - assert.Equal(t, tt.wantErr, gotErr != nil) - - retrieve := alertr.Retrieve() - assert.Equal(t, tt.wantAlerts, retrieve) - - }) - } -} - -func TestHandleGithubEnumerationErrors(t *testing.T) { - - tests := []struct { - name string - err error - wantAlerts alerter.Alerts - wantErr bool - }{ - { - name: "Handled graphql error", - err: remoteerr.NewResourceListingError(errors.New("Your token has not been granted the required scopes to execute this query."), resourcegithub.GithubTeamResourceType), - wantAlerts: alerter.Alerts{"github_team": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteGithubTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Your token has not been granted the required scopes to execute this query."), "github_team", "github_team"), alerts.EnumerationPhase)}}, - wantErr: false, - }, - { - name: "Not handled graphql error", - err: remoteerr.NewResourceListingError(errors.New("This is a not handler graphql error"), resourcegithub.GithubTeamResourceType), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - { - name: "Not Handled error type", - err: errors.New("error"), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - alertr := alerter.NewAlerter() - gotErr := HandleResourceEnumerationError(tt.err, alertr) - assert.Equal(t, tt.wantErr, gotErr != nil) - - retrieve := alertr.Retrieve() - assert.Equal(t, tt.wantAlerts, retrieve) - - }) - } -} - -func TestHandleGoogleEnumerationErrors(t *testing.T) { - tests := []struct { - name string - err error - wantAlerts alerter.Alerts - wantErr bool - }{ - { - name: "Handled 403 error", - err: remoteerr.NewResourceListingError(status.Error(codes.PermissionDenied, "useless message"), "google_type"), - wantAlerts: alerter.Alerts{"google_type": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteGoogleTerraform, remoteerr.NewResourceListingErrorWithType(status.Error(codes.PermissionDenied, "useless message"), "google_type", "google_type"), alerts.EnumerationPhase)}}, - wantErr: false, - }, - { - name: "Not handled non 403 error", - err: status.Error(codes.Unknown, ""), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - { - name: "Not Handled error type", - err: errors.New("error"), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - alertr := alerter.NewAlerter() - gotErr := HandleResourceEnumerationError(tt.err, alertr) - assert.Equal(t, tt.wantErr, gotErr != nil) - - retrieve := alertr.Retrieve() - assert.Equal(t, tt.wantAlerts, retrieve) - - }) - } -} - -func TestHandleAwsDetailsFetchingErrors(t *testing.T) { - - tests := []struct { - name string - err error - wantAlerts alerter.Alerts - wantErr bool - }{ - { - name: "Handle AccessDeniedException error", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "test", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), - wantAlerts: alerter.Alerts{"aws_vpc": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "test", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.DetailsFetchingPhase)}}, - wantErr: false, - }, - { - name: "Handle AccessDenied error", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("test", "error: AccessDenied", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), - wantAlerts: alerter.Alerts{"aws_vpc": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("test", "error: AccessDenied", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.DetailsFetchingPhase)}}, - wantErr: false, - }, - { - name: "Handle AuthorizationError error", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("test", "error: AuthorizationError", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), - wantAlerts: alerter.Alerts{"aws_vpc": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(awserr.NewRequestFailure(awserr.New("test", "error: AuthorizationError", errors.New("")), 403, ""), "aws_vpc", "aws_vpc"), alerts.DetailsFetchingPhase)}}, - wantErr: false, - }, - { - name: "Unhandled error", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("test", "error: dummy error", errors.New("")), 403, ""), resourceaws.AwsVpcResourceType), - wantAlerts: alerter.Alerts{}, - wantErr: true, - }, - { - name: "Not Handled error type", - err: errors.New("error"), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - alertr := alerter.NewAlerter() - gotErr := HandleResourceDetailsFetchingError(tt.err, alertr) - assert.Equal(t, tt.wantErr, gotErr != nil) - - retrieve := alertr.Retrieve() - assert.Equal(t, tt.wantAlerts, retrieve) - - }) - } -} - -func TestHandleGoogleDetailsFetchingErrors(t *testing.T) { - - tests := []struct { - name string - err error - wantAlerts alerter.Alerts - wantErr bool - }{ - { - name: "Handle 403 error", - err: remoteerr.NewResourceScanningError( - errors.New("Error when reading or editing Storage Bucket \"driftctl-unittest-1\": googleapi: Error 403: driftctl@elie-dev.iam.gserviceaccount.com does not have storage.buckets.get access to the Google Cloud Storage bucket., forbidden"), - "google_type", - "resource_id", - ), - wantAlerts: alerter.Alerts{"google_type.resource_id": []alerter.Alert{alerts.NewRemoteAccessDeniedAlert(common.RemoteGoogleTerraform, remoteerr.NewResourceListingErrorWithType(errors.New("Error when reading or editing Storage Bucket \"driftctl-unittest-1\": googleapi: Error 403: driftctl@elie-dev.iam.gserviceaccount.com does not have storage.buckets.get access to the Google Cloud Storage bucket., forbidden"), "google_type.resource_id", "google_type"), alerts.DetailsFetchingPhase)}}, - wantErr: false, - }, - { - name: "do not handle google unrelated error", - err: remoteerr.NewResourceScanningError( - errors.New("this string does not contains g o o g l e a p i string and thus should not be matched"), - "google_type", - "resource_id", - ), wantAlerts: alerter.Alerts{}, - wantErr: true, - }, - { - name: "do not handle google error other than 403", - err: remoteerr.NewResourceScanningError( - errors.New("Error when reading or editing Storage Bucket \"driftctl-unittest-1\": googleapi: Error 404: not found"), - "google_type", - "resource_id", - ), wantAlerts: alerter.Alerts{}, - wantErr: true, - }, - { - name: "Not Handled error type", - err: errors.New("error"), - wantAlerts: map[string][]alerter.Alert{}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - alertr := alerter.NewAlerter() - gotErr := HandleResourceDetailsFetchingError(tt.err, alertr) - assert.Equal(t, tt.wantErr, gotErr != nil) - - retrieve := alertr.Retrieve() - assert.Equal(t, tt.wantAlerts, retrieve) - - }) - } -} - -func TestEnumerationAccessDeniedAlert_GetProviderMessage(t *testing.T) { - tests := []struct { - name string - provider string - want string - }{ - { - name: "test for unsupported provider", - provider: "foobar", - want: "", - }, - { - name: "test for AWS", - provider: common.RemoteAWSTerraform, - want: "It seems that we got access denied exceptions while listing resources.\nThe latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy", - }, - { - name: "test for github", - provider: common.RemoteGithubTerraform, - want: "It seems that we got access denied exceptions while listing resources.\nPlease be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - e := alerts.NewRemoteAccessDeniedAlert(tt.provider, remoteerr.NewResourceListingErrorWithType(errors.New("dummy error"), "supplier_type", "listed_type_error"), alerts.EnumerationPhase) - if got := e.GetProviderMessage(); got != tt.want { - t.Errorf("GetProviderMessage() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestDetailsFetchingAccessDeniedAlert_GetProviderMessage(t *testing.T) { - tests := []struct { - name string - provider string - want string - }{ - { - name: "test for unsupported provider", - provider: "foobar", - want: "", - }, - { - name: "test for AWS", - provider: common.RemoteAWSTerraform, - want: "It seems that we got access denied exceptions while reading details of resources.\nThe latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy", - }, - { - name: "test for github", - provider: common.RemoteGithubTerraform, - want: "It seems that we got access denied exceptions while reading details of resources.\nPlease be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy", - }, - { - name: "test for google", - provider: common.RemoteGoogleTerraform, - want: "It seems that we got access denied exceptions while reading details of resources.\nPlease ensure that you have configured the required roles, please check our documentation at https://docs.driftctl.com/google/policy", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - e := alerts.NewRemoteAccessDeniedAlert(tt.provider, remoteerr.NewResourceListingErrorWithType(errors.New("dummy error"), "supplier_type", "listed_type_error"), alerts.DetailsFetchingPhase) - if got := e.GetProviderMessage(); got != tt.want { - t.Errorf("GetProviderMessage() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestResourceScanningErrorMethods(t *testing.T) { - - tests := []struct { - name string - err *remoteerr.ResourceScanningError - expectedError string - expectedResourceType string - }{ - { - name: "Handled error AccessDenied", - err: remoteerr.NewResourceListingError(awserr.NewRequestFailure(awserr.New("AccessDeniedException", "", errors.New("")), 403, ""), resourceaws.AwsDynamodbTableResourceType), - expectedError: "error scanning resource type aws_dynamodb_table: AccessDeniedException: \n\tstatus code: 403, request id: \ncaused by: ", - expectedResourceType: resourceaws.AwsDynamodbTableResourceType, - }, - { - name: "Handle AccessDenied error", - err: remoteerr.NewResourceListingError(errors.New("an error occured: AccessDenied: 403"), resourceaws.AwsVpcResourceType), - expectedError: "error scanning resource type aws_vpc: an error occured: AccessDenied: 403", - expectedResourceType: resourceaws.AwsVpcResourceType, - }, - { - name: "Access denied error on a single resource", - err: remoteerr.NewResourceScanningError(errors.New("Error: AccessDenied: 403 ..."), resourceaws.AwsS3BucketResourceType, "my-bucket"), - expectedError: "error scanning resource aws_s3_bucket.my-bucket: Error: AccessDenied: 403 ...", - expectedResourceType: resourceaws.AwsS3BucketResourceType, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.expectedError, tt.err.Error()) - assert.Equal(t, tt.expectedResourceType, tt.err.ResourceType()) - }) - } -} diff --git a/pkg/remote/scanner.go b/pkg/remote/scanner.go deleted file mode 100644 index 38ca0ce6..00000000 --- a/pkg/remote/scanner.go +++ /dev/null @@ -1,135 +0,0 @@ -package remote - -import ( - "context" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/parallel" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/resource" -) - -type ScannerOptions struct { - Deep bool -} - -type Scanner struct { - enumeratorRunner *parallel.ParallelRunner - detailsFetcherRunner *parallel.ParallelRunner - remoteLibrary *common.RemoteLibrary - alerter alerter.AlerterInterface - options ScannerOptions - filter filter.Filter -} - -func NewScanner(remoteLibrary *common.RemoteLibrary, alerter alerter.AlerterInterface, options ScannerOptions, filter filter.Filter) *Scanner { - return &Scanner{ - enumeratorRunner: parallel.NewParallelRunner(context.TODO(), 10), - detailsFetcherRunner: parallel.NewParallelRunner(context.TODO(), 10), - remoteLibrary: remoteLibrary, - alerter: alerter, - options: options, - filter: filter, - } -} - -func (s *Scanner) retrieveRunnerResults(runner *parallel.ParallelRunner) ([]*resource.Resource, error) { - results := make([]*resource.Resource, 0) -loop: - for { - select { - case resources, ok := <-runner.Read(): - if !ok || resources == nil { - break loop - } - - for _, res := range resources.([]*resource.Resource) { - if res != nil { - results = append(results, res) - } - } - case <-runner.DoneChan(): - break loop - } - } - return results, runner.Err() -} - -func (s *Scanner) scan() ([]*resource.Resource, error) { - for _, enumerator := range s.remoteLibrary.Enumerators() { - if s.filter.IsTypeIgnored(enumerator.SupportedType()) { - logrus.WithFields(logrus.Fields{ - "type": enumerator.SupportedType(), - }).Debug("Ignored enumeration of resources since it is ignored in filter") - continue - } - enumerator := enumerator - s.enumeratorRunner.Run(func() (interface{}, error) { - resources, err := enumerator.Enumerate() - if err != nil { - err := HandleResourceEnumerationError(err, s.alerter) - if err == nil { - return []*resource.Resource{}, nil - } - return nil, err - } - for _, res := range resources { - if res == nil { - continue - } - logrus.WithFields(logrus.Fields{ - "id": res.ResourceId(), - "type": res.ResourceType(), - }).Debug("Found cloud resource") - } - return resources, nil - }) - } - - enumerationResult, err := s.retrieveRunnerResults(s.enumeratorRunner) - if err != nil { - return nil, err - } - - if !s.options.Deep { - return enumerationResult, nil - } - - for _, res := range enumerationResult { - res := res - s.detailsFetcherRunner.Run(func() (interface{}, error) { - fetcher := s.remoteLibrary.GetDetailsFetcher(resource.ResourceType(res.ResourceType())) - if fetcher == nil { - return []*resource.Resource{res}, nil - } - - resourceWithDetails, err := fetcher.ReadDetails(res) - if err != nil { - if err := HandleResourceDetailsFetchingError(err, s.alerter); err != nil { - return nil, err - } - return []*resource.Resource{}, nil - } - return []*resource.Resource{resourceWithDetails}, nil - }) - } - - return s.retrieveRunnerResults(s.detailsFetcherRunner) -} - -func (s *Scanner) Resources() ([]*resource.Resource, error) { - resources, err := s.scan() - if err != nil { - return nil, err - } - return resources, err -} - -func (s *Scanner) Stop() { - logrus.Debug("Stopping scanner") - s.enumeratorRunner.Stop(errors.New("interrupted")) - s.detailsFetcherRunner.Stop(errors.New("interrupted")) -} diff --git a/pkg/remote/scanner_test.go b/pkg/remote/scanner_test.go deleted file mode 100644 index 77d2cebf..00000000 --- a/pkg/remote/scanner_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package remote - -import ( - "testing" - - "github.com/snyk/driftctl/pkg/alerter" - "github.com/snyk/driftctl/pkg/filter" - "github.com/snyk/driftctl/pkg/remote/common" - "github.com/snyk/driftctl/pkg/resource" - "github.com/stretchr/testify/assert" -) - -func TestScannerShouldIgnoreType(t *testing.T) { - - // Initialize mocks - alerter := alerter.NewAlerter() - fakeEnumerator := &common.MockEnumerator{} - fakeEnumerator.On("SupportedType").Return(resource.ResourceType("FakeType")) - fakeEnumerator.AssertNotCalled(t, "Enumerate") - - remoteLibrary := common.NewRemoteLibrary() - remoteLibrary.AddEnumerator(fakeEnumerator) - - testFilter := &filter.MockFilter{} - testFilter.On("IsTypeIgnored", resource.ResourceType("FakeType")).Return(true) - - s := NewScanner(remoteLibrary, alerter, ScannerOptions{}, testFilter) - _, err := s.Resources() - assert.Nil(t, err) - fakeEnumerator.AssertExpectations(t) -} diff --git a/pkg/remote/terraform/provider.go b/pkg/remote/terraform/provider.go deleted file mode 100644 index a7ad22de..00000000 --- a/pkg/remote/terraform/provider.go +++ /dev/null @@ -1,224 +0,0 @@ -package terraform - -import ( - "context" - "os" - "os/signal" - "sync" - "syscall" - "time" - - "github.com/eapache/go-resiliency/retrier" - "github.com/hashicorp/terraform/plugin" - "github.com/hashicorp/terraform/plugin/discovery" - "github.com/hashicorp/terraform/providers" - "github.com/hashicorp/terraform/terraform" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/cmd/scan" - "github.com/snyk/driftctl/pkg/output" - "github.com/zclconf/go-cty/cty" - "github.com/zclconf/go-cty/cty/gocty" - - "github.com/snyk/driftctl/pkg/parallel" - tf "github.com/snyk/driftctl/pkg/terraform" -) - -// "alias" in these struct are a way to namespace gRPC clients. -// For example, if we need to read S3 bucket from multiple AWS region, -// we'll have an alias per region, and the alias IS the region itself. -// So we can query resources using a specific custom provider configuration -type TerraformProviderConfig struct { - Name string - DefaultAlias string - GetProviderConfig func(alias string) interface{} -} - -type TerraformProvider struct { - lock sync.Mutex - providerInstaller *tf.ProviderInstaller - grpcProviders map[string]*plugin.GRPCProvider - schemas map[string]providers.Schema - Config TerraformProviderConfig - runner *parallel.ParallelRunner - progress output.Progress -} - -func NewTerraformProvider(installer *tf.ProviderInstaller, config TerraformProviderConfig, progress output.Progress) (*TerraformProvider, error) { - p := TerraformProvider{ - providerInstaller: installer, - runner: parallel.NewParallelRunner(context.TODO(), 10), - grpcProviders: make(map[string]*plugin.GRPCProvider), - Config: config, - progress: progress, - } - return &p, nil -} - -func (p *TerraformProvider) Init() error { - stopCh := make(chan bool) - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - select { - case <-c: - logrus.Warn("Detected interrupt during terraform provider configuration, cleanup ...") - p.Cleanup() - os.Exit(scan.EXIT_ERROR) - case <-stopCh: - return - } - }() - defer func() { - stopCh <- true - }() - err := p.configure(p.Config.DefaultAlias) - if err != nil { - return err - } - return nil -} - -func (p *TerraformProvider) Schema() map[string]providers.Schema { - return p.schemas -} - -func (p *TerraformProvider) Runner() *parallel.ParallelRunner { - return p.runner -} - -func (p *TerraformProvider) configure(alias string) error { - providerPath, err := p.providerInstaller.Install() - if err != nil { - return err - } - - if p.grpcProviders[alias] == nil { - logrus.WithFields(logrus.Fields{ - "alias": alias, - }).Debug("Starting gRPC client") - GRPCProvider, err := tf.NewGRPCProvider(discovery.PluginMeta{ - Path: providerPath, - }) - - if err != nil { - return err - } - p.grpcProviders[alias] = GRPCProvider - } - - schema := p.grpcProviders[alias].GetSchema() - if p.schemas == nil { - p.schemas = schema.ResourceTypes - } - - // This value is optional. It'll be overridden by the provider config. - config := cty.NullVal(cty.DynamicPseudoType) - - if p.Config.GetProviderConfig != nil { - configType := schema.Provider.Block.ImpliedType() - config, err = gocty.ToCtyValue(p.Config.GetProviderConfig(alias), configType) - if err != nil { - return err - } - } - - resp := p.grpcProviders[alias].Configure(providers.ConfigureRequest{ - Config: config, - }) - if resp.Diagnostics.HasErrors() { - return resp.Diagnostics.Err() - } - - logrus.WithFields(logrus.Fields{ - "alias": alias, - }).Debug("New gRPC client started") - - logrus.WithFields(logrus.Fields{ - "name": p.Config.Name, - "alias": alias, - }).Debug("Terraform provider initialized") - - return nil -} - -func (p *TerraformProvider) ReadResource(args tf.ReadResourceArgs) (*cty.Value, error) { - - logrus.WithFields(logrus.Fields{ - "id": args.ID, - "type": args.Ty, - "attrs": args.Attributes, - }).Debugf("Reading cloud resource") - - typ := string(args.Ty) - state := &terraform.InstanceState{ - ID: args.ID, - Attributes: map[string]string{}, - } - - alias := p.Config.DefaultAlias - if args.Attributes["alias"] != "" { - alias = args.Attributes["alias"] - delete(args.Attributes, "alias") - } - - p.lock.Lock() - if p.grpcProviders[alias] == nil { - err := p.configure(alias) - if err != nil { - return nil, err - } - } - p.lock.Unlock() - - if args.Attributes != nil && len(args.Attributes) > 0 { - // call to the provider sometimes add and delete field to their attribute this may broke caller so we deep copy attributes - state.Attributes = make(map[string]string, len(args.Attributes)) - for k, v := range args.Attributes { - state.Attributes[k] = v - } - } - - impliedType := p.schemas[typ].Block.ImpliedType() - - priorState, err := state.AttrsAsObjectValue(impliedType) - if err != nil { - return nil, err - } - - var newState cty.Value - r := retrier.New(retrier.ConstantBackoff(3, 100*time.Millisecond), nil) - - err = r.Run(func() error { - resp := p.grpcProviders[alias].ReadResource(providers.ReadResourceRequest{ - TypeName: typ, - PriorState: priorState, - Private: []byte{}, - ProviderMeta: cty.NullVal(cty.DynamicPseudoType), - }) - if resp.Diagnostics.HasErrors() { - return resp.Diagnostics.Err() - } - nonFatalErr := resp.Diagnostics.NonFatalErr() - if resp.NewState.IsNull() && nonFatalErr != nil { - return errors.Errorf("state returned by ReadResource is nil: %+v", nonFatalErr) - } - newState = resp.NewState - return nil - }) - - if err != nil { - return nil, err - } - p.progress.Inc() - return &newState, nil -} - -func (p *TerraformProvider) Cleanup() { - for alias, client := range p.grpcProviders { - logrus.WithFields(logrus.Fields{ - "alias": alias, - }).Debug("Closing gRPC client") - client.Close() - } -} diff --git a/pkg/remote/test/aws_cloudformation_stack_multiple/results.golden.json b/pkg/remote/test/aws_cloudformation_stack_multiple/results.golden.json deleted file mode 100755 index d15c9237..00000000 --- a/pkg/remote/test/aws_cloudformation_stack_multiple/results.golden.json +++ /dev/null @@ -1,42 +0,0 @@ -[ - { - "capabilities": null, - "disable_rollback": false, - "iam_role_arn": "", - "id": "arn:aws:cloudformation:us-east-1:047081014315:stack/foo-stack/c7aa0ab0-0f21-11ec-ba25-129d8c0b3757", - "name": "foo-stack", - "notification_arns": null, - "on_failure": null, - "outputs": null, - "parameters": { - "VPCCidr": "10.0.0.0/16" - }, - "policy_body": null, - "policy_url": null, - "tags": null, - "template_body": "{\"Parameters\":{\"VPCCidr\":{\"Default\":\"10.0.0.0/16\",\"Description\":\"Enter the CIDR block for the VPC. Default is 10.0.0.0/16.\",\"Type\":\"String\"}},\"Resources\":{\"myVpc\":{\"Properties\":{\"CidrBlock\":{\"Ref\":\"VPCCidr\"},\"Tags\":[{\"Key\":\"Name\",\"Value\":\"Primary_CF_VPC\"}]},\"Type\":\"AWS::EC2::VPC\"}}}", - "template_url": null, - "timeout_in_minutes": null, - "timeouts": null - }, - { - "capabilities": [ - "CAPABILITY_NAMED_IAM" - ], - "disable_rollback": false, - "iam_role_arn": "", - "id": "arn:aws:cloudformation:us-east-1:047081014315:stack/bar-stack/c7a96e70-0f21-11ec-bd2a-0a2d95c2b2ab", - "name": "bar-stack", - "notification_arns": null, - "on_failure": null, - "outputs": null, - "parameters": null, - "policy_body": null, - "policy_url": null, - "tags": null, - "template_body": "Resources:\n myUser:\n Type: AWS::IAM::User\n Properties:\n UserName: \"bar_cfn\"\n", - "template_url": null, - "timeout_in_minutes": null, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/aws_ec2_default_network_acl/results.golden.json b/pkg/remote/test/aws_ec2_default_network_acl/results.golden.json deleted file mode 100755 index 953ba503..00000000 --- a/pkg/remote/test/aws_ec2_default_network_acl/results.golden.json +++ /dev/null @@ -1,37 +0,0 @@ -[ - { - "arn": "arn:aws:ec2:us-east-1:929327065333:network-acl/acl-e88ee595", - "default_network_acl_id": null, - "egress": [ - { - "action": "allow", - "cidr_block": "0.0.0.0/0", - "from_port": 0, - "icmp_code": 0, - "icmp_type": 0, - "ipv6_cidr_block": "", - "protocol": "17", - "rule_no": 100, - "to_port": 0 - } - ], - "id": "acl-e88ee595", - "ingress": [ - { - "action": "allow", - "cidr_block": "172.31.0.0/16", - "from_port": 0, - "icmp_code": 0, - "icmp_type": 0, - "ipv6_cidr_block": "", - "protocol": "6", - "rule_no": 100, - "to_port": 0 - } - ], - "owner_id": "929327065333", - "subnet_ids": null, - "tags": null, - "vpc_id": "vpc-41d1d13b" - } -] \ No newline at end of file diff --git a/pkg/remote/test/aws_ec2_network_acl_rule/results.golden.json b/pkg/remote/test/aws_ec2_network_acl_rule/results.golden.json deleted file mode 100755 index 7224a5c8..00000000 --- a/pkg/remote/test/aws_ec2_network_acl_rule/results.golden.json +++ /dev/null @@ -1,72 +0,0 @@ -[ - { - "cidr_block": "172.31.0.0/16", - "egress": true, - "from_port": 80, - "icmp_code": null, - "icmp_type": null, - "id": "nacl-246660311", - "ipv6_cidr_block": null, - "network_acl_id": "acl-0ad6d657494d17ee2", - "protocol": "17", - "rule_action": "allow", - "rule_number": 100, - "to_port": 80 - }, - { - "cidr_block": null, - "egress": false, - "from_port": 80, - "icmp_code": null, - "icmp_type": null, - "id": "nacl-2289824980", - "ipv6_cidr_block": "::/0", - "network_acl_id": "acl-0ad6d657494d17ee2", - "protocol": "6", - "rule_action": "allow", - "rule_number": 200, - "to_port": 80 - }, - { - "cidr_block": "172.31.0.0/16", - "egress": false, - "from_port": 80, - "icmp_code": null, - "icmp_type": null, - "id": "nacl-515082162", - "ipv6_cidr_block": null, - "network_acl_id": "acl-0de54ef59074b622e", - "protocol": "17", - "rule_action": "allow", - "rule_number": 100, - "to_port": 80 - }, - { - "cidr_block": "172.31.0.0/16", - "egress": true, - "from_port": 80, - "icmp_code": null, - "icmp_type": null, - "id": "nacl-4268384215", - "ipv6_cidr_block": null, - "network_acl_id": "acl-0de54ef59074b622e", - "protocol": "17", - "rule_action": "allow", - "rule_number": 100, - "to_port": 80 - }, - { - "cidr_block": "172.31.0.0/16", - "egress": false, - "from_port": 80, - "icmp_code": null, - "icmp_type": null, - "id": "nacl-4293207588", - "ipv6_cidr_block": null, - "network_acl_id": "acl-0ad6d657494d17ee2", - "protocol": "6", - "rule_action": "allow", - "rule_number": 100, - "to_port": 80 - } -] \ No newline at end of file diff --git a/pkg/remote/test/aws_rds_clusters_results/results.golden.json b/pkg/remote/test/aws_rds_clusters_results/results.golden.json deleted file mode 100755 index af95979e..00000000 --- a/pkg/remote/test/aws_rds_clusters_results/results.golden.json +++ /dev/null @@ -1,112 +0,0 @@ -[ - { - "allow_major_version_upgrade": null, - "apply_immediately": null, - "arn": "arn:aws:rds:us-east-1:533948124879:cluster:aurora-cluster-demo", - "availability_zones": [ - "us-east-1a", - "us-east-1b", - "us-east-1d" - ], - "backtrack_window": 0, - "backup_retention_period": 5, - "cluster_identifier": "aurora-cluster-demo", - "cluster_identifier_prefix": null, - "cluster_members": [ - "aurora-cluster-demo-0" - ], - "cluster_resource_id": "cluster-TISYDSSX4J5R6ZGUTV6LLJW73E", - "copy_tags_to_snapshot": false, - "database_name": "mydb", - "db_cluster_parameter_group_name": "default.aurora-postgresql11", - "db_subnet_group_name": "default", - "deletion_protection": false, - "enable_http_endpoint": false, - "enabled_cloudwatch_logs_exports": null, - "endpoint": "aurora-cluster-demo.cluster-cd539r6quiux.us-east-1.rds.amazonaws.com", - "engine": "aurora-postgresql", - "engine_mode": "provisioned", - "engine_version": "11.9", - "final_snapshot_identifier": null, - "global_cluster_identifier": "", - "hosted_zone_id": "Z2R2ITUGPM61AM", - "iam_database_authentication_enabled": false, - "iam_roles": null, - "id": "aurora-cluster-demo", - "kms_key_id": "", - "master_password": null, - "master_username": "foo", - "port": 5432, - "preferred_backup_window": "07:00-09:00", - "preferred_maintenance_window": "fri:03:03-fri:03:33", - "reader_endpoint": "aurora-cluster-demo.cluster-ro-cd539r6quiux.us-east-1.rds.amazonaws.com", - "replication_source_identifier": "", - "restore_to_point_in_time": null, - "s3_import": null, - "scaling_configuration": null, - "skip_final_snapshot": false, - "snapshot_identifier": null, - "source_region": null, - "storage_encrypted": false, - "tags": null, - "timeouts": null, - "vpc_security_group_ids": [ - "sg-49e38646" - ] - }, - { - "allow_major_version_upgrade": null, - "apply_immediately": null, - "arn": "arn:aws:rds:us-east-1:533948124879:cluster:aurora-cluster-demo", - "availability_zones": [ - "us-east-1a", - "us-east-1b", - "us-east-1d" - ], - "backtrack_window": 0, - "backup_retention_period": 5, - "cluster_identifier": "aurora-cluster-demo-2", - "cluster_identifier_prefix": null, - "cluster_members": [ - "aurora-cluster-demo-0" - ], - "cluster_resource_id": "cluster-TISYDSSX4J5R6ZGUTV6LLJW73E", - "copy_tags_to_snapshot": false, - "database_name": "", - "db_cluster_parameter_group_name": "default.aurora-postgresql11", - "db_subnet_group_name": "default", - "deletion_protection": false, - "enable_http_endpoint": false, - "enabled_cloudwatch_logs_exports": null, - "endpoint": "aurora-cluster-demo.cluster-cd539r6quiux.us-east-1.rds.amazonaws.com", - "engine": "aurora-postgresql", - "engine_mode": "provisioned", - "engine_version": "11.9", - "final_snapshot_identifier": null, - "global_cluster_identifier": "", - "hosted_zone_id": "Z2R2ITUGPM61AM", - "iam_database_authentication_enabled": false, - "iam_roles": null, - "id": "aurora-cluster-demo-2", - "kms_key_id": "", - "master_password": null, - "master_username": "foo", - "port": 5432, - "preferred_backup_window": "07:00-09:00", - "preferred_maintenance_window": "fri:03:03-fri:03:33", - "reader_endpoint": "aurora-cluster-demo.cluster-ro-cd539r6quiux.us-east-1.rds.amazonaws.com", - "replication_source_identifier": "", - "restore_to_point_in_time": null, - "s3_import": null, - "scaling_configuration": null, - "skip_final_snapshot": false, - "snapshot_identifier": null, - "source_region": null, - "storage_encrypted": false, - "tags": null, - "timeouts": null, - "vpc_security_group_ids": [ - "sg-49e38646" - ] - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_lb_rule_multiple/results.golden.json b/pkg/remote/test/azurerm_lb_rule_multiple/results.golden.json deleted file mode 100755 index 132a1c16..00000000 --- a/pkg/remote/test/azurerm_lb_rule_multiple/results.golden.json +++ /dev/null @@ -1,40 +0,0 @@ -[ - { - "backend_address_pool_id": "", - "backend_port": 80, - "disable_outbound_snat": false, - "enable_floating_ip": false, - "enable_tcp_reset": false, - "frontend_ip_configuration_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", - "frontend_ip_configuration_name": "PublicIPAddress", - "frontend_port": 80, - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule2", - "idle_timeout_in_minutes": 4, - "load_distribution": "Default", - "loadbalancer_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer", - "name": "LBRule2", - "probe_id": "", - "protocol": "Tcp", - "resource_group_name": "raphael-dev", - "timeouts": null - }, - { - "backend_address_pool_id": "", - "backend_port": 3389, - "disable_outbound_snat": false, - "enable_floating_ip": false, - "enable_tcp_reset": false, - "frontend_ip_configuration_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/frontendIPConfigurations/PublicIPAddress", - "frontend_ip_configuration_name": "PublicIPAddress", - "frontend_port": 3389, - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer/loadBalancingRules/LBRule", - "idle_timeout_in_minutes": 4, - "load_distribution": "Default", - "loadbalancer_id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/raphael-dev/providers/Microsoft.Network/loadBalancers/TestLoadBalancer", - "name": "LBRule", - "probe_id": "", - "protocol": "Tcp", - "resource_group_name": "raphael-dev", - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_network_security_group_multiple/results.golden.json b/pkg/remote/test/azurerm_network_security_group_multiple/results.golden.json deleted file mode 100755 index 5a57c3a1..00000000 --- a/pkg/remote/test/azurerm_network_security_group_multiple/results.golden.json +++ /dev/null @@ -1,41 +0,0 @@ -[ - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup2", - "location": "westeurope", - "name": "acceptanceTestSecurityGroup2", - "resource_group_name": "example-resources", - "security_rule": null, - "tags": null, - "timeouts": null - }, - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/example-resources/providers/Microsoft.Network/networkSecurityGroups/acceptanceTestSecurityGroup1", - "location": "westeurope", - "name": "acceptanceTestSecurityGroup1", - "resource_group_name": "example-resources", - "security_rule": [ - { - "access": "Allow", - "description": "", - "destination_address_prefix": "*", - "destination_address_prefixes": null, - "destination_application_security_group_ids": null, - "destination_port_range": "*", - "destination_port_ranges": null, - "direction": "Inbound", - "name": "test123", - "priority": 100, - "protocol": "Tcp", - "source_address_prefix": "*", - "source_address_prefixes": null, - "source_application_security_group_ids": null, - "source_port_range": "*", - "source_port_ranges": null - } - ], - "tags": { - "environment": "Production" - }, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_a_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_a_record_multiple/results.golden.json deleted file mode 100755 index d6aeff7a..00000000 --- a/pkg/remote/test/azurerm_private_dns_a_record_multiple/results.golden.json +++ /dev/null @@ -1,29 +0,0 @@ -[ - { - "fqdn": "test.thisisatestusingtf.com.", - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/test", - "name": "test", - "records": [ - "10.0.180.17", - "10.0.180.20" - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "othertest.thisisatestusingtf.com.", - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/A/othertest", - "name": "othertest", - "records": [ - "10.0.180.20" - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/results.golden.json deleted file mode 100755 index 6af37161..00000000 --- a/pkg/remote/test/azurerm_private_dns_aaaaa_record_multiple/results.golden.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "fqdn": "othertest.thisisatestusingtf.com.", - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/othertest", - "name": "othertest", - "records": [ - "fd5d:70bc:930e:d008:0000:0000:0000:7334", - "fd5d:70bc:930e:d008::7335" - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "test.thisisatestusingtf.com.", - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/AAAA/test", - "name": "test", - "records": [ - "fd5d:70bc:930e:d008:0000:0000:0000:7334", - "fd5d:70bc:930e:d008::7335" - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_cname_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_cname_record_multiple/results.golden.json deleted file mode 100755 index 91a68e86..00000000 --- a/pkg/remote/test/azurerm_private_dns_cname_record_multiple/results.golden.json +++ /dev/null @@ -1,24 +0,0 @@ -[ - { - "fqdn": "test.thisisatestusingtf.com.", - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/test", - "name": "test", - "record": "test.com", - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "othertest.thisisatestusingtf.com.", - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/CNAME/othertest", - "name": "othertest", - "record": "othertest.com", - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_mx_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_mx_record_multiple/results.golden.json deleted file mode 100755 index 36dd442e..00000000 --- a/pkg/remote/test/azurerm_private_dns_mx_record_multiple/results.golden.json +++ /dev/null @@ -1,38 +0,0 @@ -[ - { - "fqdn": "testmx.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/testmx", - "name": "testmx", - "record": [ - { - "exchange": "bkpmx.thisisatestusingtf.com", - "preference": 30 - } - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "othertestmx.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/MX/othertestmx", - "name": "othertestmx", - "record": [ - { - "exchange": "backupmx.thisisatestusingtf.com", - "preference": 20 - }, - { - "exchange": "mx.thisisatestusingtf.com", - "preference": 10 - } - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_private_zone_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_private_zone_multiple/results.golden.json deleted file mode 100755 index 96fbf208..00000000 --- a/pkg/remote/test/azurerm_private_dns_private_zone_multiple/results.golden.json +++ /dev/null @@ -1,79 +0,0 @@ -[ - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf2.com", - "max_number_of_record_sets": 25000, - "max_number_of_virtual_network_links": 1000, - "max_number_of_virtual_network_links_with_registration": 100, - "name": "thisisatestusingtf2.com", - "number_of_record_sets": null, - "resource_group_name": "martin-dev", - "soa_record": [ - { - "email": "azureprivatedns-host.microsoft.com", - "expire_time": 2419200, - "fqdn": "thisisatestusingtf2.com.", - "host_name": "azureprivatedns.net", - "minimum_ttl": 10, - "refresh_time": 3600, - "retry_time": 300, - "serial_number": 1, - "tags": null, - "ttl": 3600 - } - ], - "tags": null, - "timeouts": null - }, - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/testmartin.com", - "max_number_of_record_sets": 25000, - "max_number_of_virtual_network_links": 1000, - "max_number_of_virtual_network_links_with_registration": 100, - "name": "testmartin.com", - "number_of_record_sets": null, - "resource_group_name": "martin-dev", - "soa_record": [ - { - "email": "azureprivatedns-host.microsoft.com", - "expire_time": 2419200, - "fqdn": "testmartin.com.", - "host_name": "azureprivatedns.net", - "minimum_ttl": 10, - "refresh_time": 3600, - "retry_time": 300, - "serial_number": 1, - "tags": null, - "ttl": 3600 - } - ], - "tags": { - "test": "test" - }, - "timeouts": null - }, - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com", - "max_number_of_record_sets": 25000, - "max_number_of_virtual_network_links": 1000, - "max_number_of_virtual_network_links_with_registration": 100, - "name": "thisisatestusingtf.com", - "number_of_record_sets": null, - "resource_group_name": "martin-dev", - "soa_record": [ - { - "email": "azureprivatedns-host.microsoft.com", - "expire_time": 2419200, - "fqdn": "thisisatestusingtf.com.", - "host_name": "azureprivatedns.net", - "minimum_ttl": 10, - "refresh_time": 3600, - "retry_time": 300, - "serial_number": 1, - "tags": null, - "ttl": 3600 - } - ], - "tags": null, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_ptr_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_ptr_record_multiple/results.golden.json deleted file mode 100755 index afa5d613..00000000 --- a/pkg/remote/test/azurerm_private_dns_ptr_record_multiple/results.golden.json +++ /dev/null @@ -1,29 +0,0 @@ -[ - { - "fqdn": "testptr.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/testptr", - "name": "testptr", - "records": [ - "ptr3.thisisatestusingtf.com" - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "othertestptr.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/PTR/othertestptr", - "name": "othertestptr", - "records": [ - "ptr1.thisisatestusingtf.com", - "ptr2.thisisatestusingtf.com" - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_srv_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_srv_record_multiple/results.golden.json deleted file mode 100755 index d5f0b670..00000000 --- a/pkg/remote/test/azurerm_private_dns_srv_record_multiple/results.golden.json +++ /dev/null @@ -1,44 +0,0 @@ -[ - { - "fqdn": "othertestptr.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/othertestptr", - "name": "othertestptr", - "record": [ - { - "port": 8080, - "priority": 10, - "target": "srv2.thisisatestusingtf.com", - "weight": 10 - }, - { - "port": 8080, - "priority": 1, - "target": "srv1.thisisatestusingtf.com", - "weight": 5 - } - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "testptr.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/SRV/testptr", - "name": "testptr", - "record": [ - { - "port": 8080, - "priority": 20, - "target": "srv3.thisisatestusingtf.com", - "weight": 15 - } - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_private_dns_txt_record_multiple/results.golden.json b/pkg/remote/test/azurerm_private_dns_txt_record_multiple/results.golden.json deleted file mode 100755 index a52f6bf1..00000000 --- a/pkg/remote/test/azurerm_private_dns_txt_record_multiple/results.golden.json +++ /dev/null @@ -1,35 +0,0 @@ -[ - { - "fqdn": "testtxt.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/testtxt", - "name": "testtxt", - "record": [ - { - "value": "this is value line 3" - } - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - }, - { - "fqdn": "othertesttxt.thisisatestusingtf.com.", - "id": "/subscriptions/8cb43347-a79f-4bb2-a8b4-c838b41fa5a5/resourceGroups/martin-dev/providers/Microsoft.Network/privateDnsZones/thisisatestusingtf.com/TXT/othertesttxt", - "name": "othertesttxt", - "record": [ - { - "value": "this is value line 1" - }, - { - "value": "this is value line 2" - } - ], - "resource_group_name": "martin-dev", - "tags": null, - "timeouts": null, - "ttl": 300, - "zone_name": "thisisatestusingtf.com" - } -] \ No newline at end of file diff --git a/pkg/remote/test/azurerm_ssh_public_key_multiple/results.golden.json b/pkg/remote/test/azurerm_ssh_public_key_multiple/results.golden.json deleted file mode 100755 index 5ff4d335..00000000 --- a/pkg/remote/test/azurerm_ssh_public_key_multiple/results.golden.json +++ /dev/null @@ -1,20 +0,0 @@ -[ - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key2", - "location": "westeurope", - "name": "example-key2", - "public_key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCjeC5sO1EdEfZOrdVCpuOgXcXsKZg9zgfJbHgQgX1R2Nd8mNQrUjpsB4XLHNZ3T6UYrsSh7oxYC3UFu6peO4LmA2WTe2wCWVFn9WW/Lo99WcA/G/fGj6s5HK5CHFVPXnNxM47QJMNm5BWOM55+EWP839SHLH9Fk63H575x7jxZvBvaV0uL84XuVpiEBKhnpQfT4cJGoGLOGgjM+TpHyosbKldu5q2UTF9nOGpmLuku41oihqiPPSJnJRv3TDKFi4mIl9Iz5HJINWvLl1kdCfyjPcHcH5GO0tuA9rP5AbsmG5EAGOKtuFipYA4MyY9SYriZ2V1vpgefUS9lilg9hIPEj/8ZPTxf62XeyC1dQ3cOz6yPWR2sODyVECVf6mrmhZPTjVX+DorByX2uBzLDzF9jGMFMJRhxi0yVpXsqBrP+ps9G+s7oNUDp771d1Bix+gm5EyebEbdiQuf0/8wDlhY5jYAFJW1xkPKXcjJdM1FuVVS1B8zhvRVEJZUngruVfh/7jJUOWNS44F7rVz5a4r/vs84ObFIMeYdFn+uxgUqOlNMAvXLvJ2GzlPXInXW90Uv+JJ5msny/5ygGfHr2D6xOf6P7r7oSalXwjd9BcRS6/4GQAY6LVfPwrpnrpyJBiK/FhEbR+ctfDo81eKhmp0EyxvSJGW46/26/kqHvchf+rQ== ?\n", - "resource_group_name": "TESTRESGROUP", - "tags": null, - "timeouts": null - }, - { - "id": "/subscriptions/7bfb2c5c-7308-46ed-8ae4-fffa356eb406/resourceGroups/TESTRESGROUP/providers/Microsoft.Compute/sshPublicKeys/example-key", - "location": "westeurope", - "name": "example-key", - "public_key": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCjeC5sO1EdEfZOrdVCpuOgXcXsKZg9zgfJbHgQgX1R2Nd8mNQrUjpsB4XLHNZ3T6UYrsSh7oxYC3UFu6peO4LmA2WTe2wCWVFn9WW/Lo99WcA/G/fGj6s5HK5CHFVPXnNxM47QJMNm5BWOM55+EWP839SHLH9Fk63H575x7jxZvBvaV0uL84XuVpiEBKhnpQfT4cJGoGLOGgjM+TpHyosbKldu5q2UTF9nOGpmLuku41oihqiPPSJnJRv3TDKFi4mIl9Iz5HJINWvLl1kdCfyjPcHcH5GO0tuA9rP5AbsmG5EAGOKtuFipYA4MyY9SYriZ2V1vpgefUS9lilg9hIPEj/8ZPTxf62XeyC1dQ3cOz6yPWR2sODyVECVf6mrmhZPTjVX+DorByX2uBzLDzF9jGMFMJRhxi0yVpXsqBrP+ps9G+s7oNUDp771d1Bix+gm5EyebEbdiQuf0/8wDlhY5jYAFJW1xkPKXcjJdM1FuVVS1B8zhvRVEJZUngruVfh/7jJUOWNS44F7rVz5a4r/vs84ObFIMeYdFn+uxgUqOlNMAvXLvJ2GzlPXInXW90Uv+JJ5msny/5ygGfHr2D6xOf6P7r7oSalXwjd9BcRS6/4GQAY6LVfPwrpnrpyJBiK/FhEbR+ctfDo81eKhmp0EyxvSJGW46/26/kqHvchf+rQ== ?\n", - "resource_group_name": "TESTRESGROUP", - "tags": null, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_firewall/results.golden.json b/pkg/remote/test/google_compute_firewall/results.golden.json deleted file mode 100755 index f82df3cd..00000000 --- a/pkg/remote/test/google_compute_firewall/results.golden.json +++ /dev/null @@ -1,116 +0,0 @@ -[ - { - "allow": [ - { - "ports": [ - "80", - "8080", - "1000-2000" - ], - "protocol": "tcp" - }, - { - "ports": null, - "protocol": "icmp" - } - ], - "creation_timestamp": "2021-09-14T05:21:08.730-07:00", - "deny": null, - "description": "", - "destination_ranges": null, - "direction": "INGRESS", - "disabled": false, - "enable_logging": null, - "id": "projects/cloudskiff-dev-elie/global/firewalls/test-firewall-1", - "log_config": null, - "name": "test-firewall-1", - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/networks/test-network", - "priority": 1000, - "project": "cloudskiff-dev-elie", - "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-1", - "source_ranges": null, - "source_service_accounts": null, - "source_tags": [ - "web" - ], - "target_service_accounts": null, - "target_tags": null, - "timeouts": null - }, - { - "allow": [ - { - "ports": [ - "80", - "8080", - "1000-2000" - ], - "protocol": "tcp" - }, - { - "ports": null, - "protocol": "icmp" - } - ], - "creation_timestamp": "2021-09-14T05:21:08.744-07:00", - "deny": null, - "description": "", - "destination_ranges": null, - "direction": "INGRESS", - "disabled": false, - "enable_logging": null, - "id": "projects/cloudskiff-dev-elie/global/firewalls/test-firewall-0", - "log_config": null, - "name": "test-firewall-0", - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/networks/test-network", - "priority": 1000, - "project": "cloudskiff-dev-elie", - "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-0", - "source_ranges": null, - "source_service_accounts": null, - "source_tags": [ - "web" - ], - "target_service_accounts": null, - "target_tags": null, - "timeouts": null - }, - { - "allow": [ - { - "ports": [ - "80", - "8080", - "1000-2000" - ], - "protocol": "tcp" - }, - { - "ports": null, - "protocol": "icmp" - } - ], - "creation_timestamp": "2021-09-14T05:21:08.624-07:00", - "deny": null, - "description": "", - "destination_ranges": null, - "direction": "INGRESS", - "disabled": false, - "enable_logging": null, - "id": "projects/cloudskiff-dev-elie/global/firewalls/test-firewall-2", - "log_config": null, - "name": "test-firewall-2", - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/networks/test-network", - "priority": 1000, - "project": "cloudskiff-dev-elie", - "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-elie/global/firewalls/test-firewall-2", - "source_ranges": null, - "source_service_accounts": null, - "source_tags": [ - "web" - ], - "target_service_accounts": null, - "target_tags": null, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_instance_group/results.golden.json b/pkg/remote/test/google_compute_instance_group/results.golden.json deleted file mode 100755 index cc28d15f..00000000 --- a/pkg/remote/test/google_compute_instance_group/results.golden.json +++ /dev/null @@ -1,28 +0,0 @@ -[ - { - "description": "Terraform test instance group", - "id": "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-2", - "instances": null, - "name": "driftctl-test-2", - "named_port": null, - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network", - "project": "cloudskiff-dev-raphael", - "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-2", - "size": 0, - "timeouts": null, - "zone": "us-central1-a" - }, - { - "description": "Terraform test instance group", - "id": "projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-1", - "instances": null, - "name": "driftctl-test-1", - "named_port": null, - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network", - "project": "cloudskiff-dev-raphael", - "self_link": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/zones/us-central1-a/instanceGroups/driftctl-test-1", - "size": 0, - "timeouts": null, - "zone": "us-central1-a" - } -] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_network/results.golden.json b/pkg/remote/test/google_compute_network/results.golden.json deleted file mode 100755 index e6e2f8f6..00000000 --- a/pkg/remote/test/google_compute_network/results.golden.json +++ /dev/null @@ -1,41 +0,0 @@ -[ - { - "auto_create_subnetworks": false, - "delete_default_routes_on_create": false, - "description": "", - "gateway_ipv4": null, - "id": "projects/driftctl-qa-1/global/networks/driftctl-unittest-3", - "mtu": 1460, - "name": "driftctl-unittest-3", - "project": "driftctl-qa-1", - "routing_mode": "REGIONAL", - "self_link": null, - "timeouts": null - }, - { - "auto_create_subnetworks": true, - "delete_default_routes_on_create": false, - "description": "", - "gateway_ipv4": null, - "id": "projects/driftctl-qa-1/global/networks/driftctl-unittest-2", - "mtu": 1460, - "name": "driftctl-unittest-2", - "project": "driftctl-qa-1", - "routing_mode": "REGIONAL", - "self_link": null, - "timeouts": null - }, - { - "auto_create_subnetworks": false, - "delete_default_routes_on_create": false, - "description": "", - "gateway_ipv4": null, - "id": "projects/driftctl-qa-1/global/networks/driftctl-unittest-1", - "mtu": 1460, - "name": "driftctl-unittest-1", - "project": "driftctl-qa-1", - "routing_mode": "REGIONAL", - "self_link": null, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/remote/test/google_compute_subnetwork_multiple/results.golden.json b/pkg/remote/test/google_compute_subnetwork_multiple/results.golden.json deleted file mode 100755 index 8a6a07c4..00000000 --- a/pkg/remote/test/google_compute_subnetwork_multiple/results.golden.json +++ /dev/null @@ -1,71 +0,0 @@ -[ - { - "creation_timestamp": "2021-10-20T07:39:34.673-07:00", - "description": "", - "fingerprint": null, - "gateway_address": "10.2.0.1", - "id": "projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-1", - "ip_cidr_range": "10.2.0.0/16", - "log_config": null, - "name": "driftctl-unittest-1", - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network-1871346572", - "private_ip_google_access": false, - "private_ipv6_google_access": "DISABLE_GOOGLE_ACCESS", - "project": "cloudskiff-dev-raphael", - "region": "us-central1", - "secondary_ip_range": [ - { - "ip_cidr_range": "192.168.10.0/24", - "range_name": "tf-test-secondary-range-update1" - } - ], - "self_link": null, - "timeouts": null - }, - { - "creation_timestamp": "2021-10-20T07:39:45.114-07:00", - "description": "", - "fingerprint": null, - "gateway_address": "10.2.0.1", - "id": "projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-2", - "ip_cidr_range": "10.2.0.0/16", - "log_config": null, - "name": "driftctl-unittest-2", - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network-2871346572", - "private_ip_google_access": false, - "private_ipv6_google_access": "DISABLE_GOOGLE_ACCESS", - "project": "cloudskiff-dev-raphael", - "region": "us-central1", - "secondary_ip_range": [ - { - "ip_cidr_range": "192.168.10.0/24", - "range_name": "tf-test-secondary-range-update1" - } - ], - "self_link": null, - "timeouts": null - }, - { - "creation_timestamp": "2021-10-20T07:39:34.650-07:00", - "description": "", - "fingerprint": null, - "gateway_address": "10.2.0.1", - "id": "projects/cloudskiff-dev-raphael/regions/us-central1/subnetworks/driftctl-unittest-3", - "ip_cidr_range": "10.2.0.0/16", - "log_config": null, - "name": "driftctl-unittest-3", - "network": "https://www.googleapis.com/compute/v1/projects/cloudskiff-dev-raphael/global/networks/test-network-3871346572", - "private_ip_google_access": false, - "private_ipv6_google_access": "DISABLE_GOOGLE_ACCESS", - "project": "cloudskiff-dev-raphael", - "region": "us-central1", - "secondary_ip_range": [ - { - "ip_cidr_range": "192.168.10.0/24", - "range_name": "tf-test-secondary-range-update1" - } - ], - "self_link": null, - "timeouts": null - } -] \ No newline at end of file diff --git a/pkg/resource/aws/aws_ami.go b/pkg/resource/aws/aws_ami.go index b5fd80d3..95d52108 100644 --- a/pkg/resource/aws/aws_ami.go +++ b/pkg/resource/aws/aws_ami.go @@ -1,15 +1,14 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsAmiResourceType = "aws_ami" - func initAwsAmiMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsAmiResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsAmiResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsAmiResourceType, resource.FlagDeepMode) + resourceSchemaRepository.SetFlags(aws.AwsAmiResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_apigatewayv2_mapping.go b/pkg/resource/aws/aws_apigatewayv2_mapping.go deleted file mode 100644 index 7bd05307..00000000 --- a/pkg/resource/aws/aws_apigatewayv2_mapping.go +++ /dev/null @@ -1,23 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsApiGatewayV2MappingResourceType = "aws_apigatewayv2_api_mapping" - -func initAwsApiGatewayV2MappingMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc( - AwsApiGatewayV2MappingResourceType, - func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("api_id"); v != nil { - attrs["Api"] = *v - } - if v := res.Attributes().GetString("stage"); v != nil { - attrs["Stage"] = *v - } - - return attrs - }, - ) -} diff --git a/pkg/resource/aws/aws_apigatewayv2_model.go b/pkg/resource/aws/aws_apigatewayv2_model.go deleted file mode 100644 index 3919784f..00000000 --- a/pkg/resource/aws/aws_apigatewayv2_model.go +++ /dev/null @@ -1,16 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsApiGatewayV2ModelResourceType = "aws_apigatewayv2_model" - -func initAwsApiGatewayV2ModelMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc( - AwsApiGatewayV2ModelResourceType, - func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attributes().GetString("name"), - } - }, - ) -} diff --git a/pkg/resource/aws/aws_appautoscaling_policy.go b/pkg/resource/aws/aws_appautoscaling_policy.go deleted file mode 100644 index 1bd7ce6c..00000000 --- a/pkg/resource/aws/aws_appautoscaling_policy.go +++ /dev/null @@ -1,24 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsAppAutoscalingPolicyResourceType = "aws_appautoscaling_policy" - -func initAwsAppAutoscalingPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsAppAutoscalingPolicyResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attributes().GetString("name"), - "resource_id": *res.Attributes().GetString("resource_id"), - "service_namespace": *res.Attributes().GetString("service_namespace"), - "scalable_dimension": *res.Attributes().GetString("scalable_dimension"), - } - }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsAppAutoscalingPolicyResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - if v := res.Attributes().GetString("scalable_dimension"); v != nil && *v != "" { - attrs["Scalable dimension"] = *v - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsAppAutoscalingPolicyResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_appautoscaling_target.go b/pkg/resource/aws/aws_appautoscaling_target.go index 8c2d6dbf..110037a8 100644 --- a/pkg/resource/aws/aws_appautoscaling_target.go +++ b/pkg/resource/aws/aws_appautoscaling_target.go @@ -1,25 +1,12 @@ package aws -import "github.com/snyk/driftctl/pkg/resource" - -const AwsAppAutoscalingTargetResourceType = "aws_appautoscaling_target" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) func initAwsAppAutoscalingTargetMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsAppAutoscalingTargetResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "service_namespace": *res.Attributes().GetString("service_namespace"), - "scalable_dimension": *res.Attributes().GetString("scalable_dimension"), - } - }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsAppAutoscalingTargetResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - if v := res.Attributes().GetString("scalable_dimension"); v != nil && *v != "" { - attrs["Scalable dimension"] = *v - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsAppAutoscalingTargetResourceType, resource.FlagDeepMode) - resourceSchemaRepository.SetDiscriminantFunc(AwsAppAutoscalingTargetResourceType, func(self, target *resource.Resource) bool { + resourceSchemaRepository.SetDiscriminantFunc(aws.AwsAppAutoscalingTargetResourceType, func(self, target *resource.Resource) bool { return *self.Attributes().GetString("scalable_dimension") == *target.Attributes().GetString("scalable_dimension") }) } diff --git a/pkg/resource/aws/aws_cloudformation_stack.go b/pkg/resource/aws/aws_cloudformation_stack.go index e092bd8f..b3a53dfb 100644 --- a/pkg/resource/aws/aws_cloudformation_stack.go +++ b/pkg/resource/aws/aws_cloudformation_stack.go @@ -1,26 +1,13 @@ package aws import ( - "strconv" - - "github.com/hashicorp/terraform/flatmap" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsCloudformationStackResourceType = "aws_cloudformation_stack" - func initAwsCloudformationStackMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsCloudformationStackResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]interface{}) - if v := res.Attributes().GetMap("parameters"); v != nil { - attrs["parameters.%"] = strconv.FormatInt(int64(len(v)), 10) - attrs["parameters"] = v - } - return flatmap.Flatten(attrs) - }) - resourceSchemaRepository.SetNormalizeFunc(AwsCloudformationStackResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsCloudformationStackResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsCloudformationStackResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_cloudfront_distribution.go b/pkg/resource/aws/aws_cloudfront_distribution.go index 0ac48a9d..71450602 100644 --- a/pkg/resource/aws/aws_cloudfront_distribution.go +++ b/pkg/resource/aws/aws_cloudfront_distribution.go @@ -1,13 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsCloudfrontDistributionResourceType = "aws_cloudfront_distribution" - func initAwsCloudfrontDistributionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsCloudfrontDistributionResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsCloudfrontDistributionResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"etag"}) val.SafeDelete([]string{"last_modified_time"}) @@ -15,5 +14,4 @@ func initAwsCloudfrontDistributionMetaData(resourceSchemaRepository resource.Sch val.SafeDelete([]string{"status"}) val.SafeDelete([]string{"wait_for_deployment"}) }) - resourceSchemaRepository.SetFlags(AwsCloudfrontDistributionResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_cloudfront_distribution_test.go b/pkg/resource/aws/aws_cloudfront_distribution_test.go index c03dffb6..84d0106c 100644 --- a/pkg/resource/aws/aws_cloudfront_distribution_test.go +++ b/pkg/resource/aws/aws_cloudfront_distribution_test.go @@ -5,8 +5,8 @@ import ( "github.com/aws/aws-sdk-go/service/cloudfront" + awsresources "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/analyser" - awsresources "github.com/snyk/driftctl/pkg/resource/aws" "github.com/snyk/driftctl/test" "github.com/r3labs/diff/v2" diff --git a/pkg/resource/aws/aws_db_instance.go b/pkg/resource/aws/aws_db_instance.go index 8290326a..291e4722 100644 --- a/pkg/resource/aws/aws_db_instance.go +++ b/pkg/resource/aws/aws_db_instance.go @@ -1,13 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsDbInstanceResourceType = "aws_db_instance" - func initAwsDbInstanceMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsDbInstanceResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsDbInstanceResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"delete_automated_backups"}) val.SafeDelete([]string{"final_snapshot_identifier"}) @@ -20,5 +19,4 @@ func initAwsDbInstanceMetaData(resourceSchemaRepository resource.SchemaRepositor val.SafeDelete([]string{"apply_immediately"}) val.DeleteIfDefault("CharacterSetName") }) - resourceSchemaRepository.SetFlags(AwsDbInstanceResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_db_subnet_group.go b/pkg/resource/aws/aws_db_subnet_group.go index c87ba216..f02e554f 100644 --- a/pkg/resource/aws/aws_db_subnet_group.go +++ b/pkg/resource/aws/aws_db_subnet_group.go @@ -1,15 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsDbSubnetGroupResourceType = "aws_db_subnet_group" - func initAwsDbSubnetGroupMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsDbSubnetGroupResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsDbSubnetGroupResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"name_prefix"}) }) - resourceSchemaRepository.SetFlags(AwsDbSubnetGroupResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_default_network_acl.go b/pkg/resource/aws/aws_default_network_acl.go index bfbc8ec9..0279b62e 100644 --- a/pkg/resource/aws/aws_default_network_acl.go +++ b/pkg/resource/aws/aws_default_network_acl.go @@ -1,14 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsDefaultNetworkACLResourceType = "aws_default_network_acl" - func initAwsDefaultNetworkACLMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsDefaultNetworkACLResourceType, resource.FlagDeepMode) - resourceSchemaRepository.SetNormalizeFunc(AwsDefaultNetworkACLResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsDefaultNetworkACLResourceType, func(res *resource.Resource) { res.Attrs.SafeDelete([]string{"default_network_acl_id"}) // https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/default_network_acl#managing-subnets-in-a-default-network-acl diff --git a/pkg/resource/aws/aws_default_route_table.go b/pkg/resource/aws/aws_default_route_table.go deleted file mode 100644 index 0f18d2e8..00000000 --- a/pkg/resource/aws/aws_default_route_table.go +++ /dev/null @@ -1,18 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsDefaultRouteTableResourceType = "aws_default_route_table" - -func initAwsDefaultRouteTableMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsDefaultRouteTableResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "vpc_id": *res.Attributes().GetString("vpc_id"), - } - }) - resourceSchemaRepository.SetFlags(AwsDefaultRouteTableResourceType, resource.FlagDeepMode) - resourceSchemaRepository.SetNormalizeFunc(AwsDefaultRouteTableResourceType, func(res *resource.Resource) { - val := res.Attrs - val.SafeDelete([]string{"timeouts"}) - }) -} diff --git a/pkg/resource/aws/aws_default_security_group.go b/pkg/resource/aws/aws_default_security_group.go index a59ba3fc..a85a9b10 100644 --- a/pkg/resource/aws/aws_default_security_group.go +++ b/pkg/resource/aws/aws_default_security_group.go @@ -1,17 +1,15 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsDefaultSecurityGroupResourceType = "aws_default_security_group" - func initAwsDefaultSecurityGroupMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsDefaultSecurityGroupResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsDefaultSecurityGroupResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"revoke_rules_on_delete"}) val.SafeDelete([]string{"ingress"}) val.SafeDelete([]string{"egress"}) }) - resourceSchemaRepository.SetFlags(AwsDefaultSecurityGroupResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_default_subnet.go b/pkg/resource/aws/aws_default_subnet.go index 622fad33..fd166390 100644 --- a/pkg/resource/aws/aws_default_subnet.go +++ b/pkg/resource/aws/aws_default_subnet.go @@ -1,15 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsDefaultSubnetResourceType = "aws_default_subnet" - func initAwsDefaultSubnetMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsDefaultSubnetResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsDefaultSubnetResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsDefaultSubnetResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_default_vpc.go b/pkg/resource/aws/aws_default_vpc.go deleted file mode 100644 index c9319158..00000000 --- a/pkg/resource/aws/aws_default_vpc.go +++ /dev/null @@ -1,9 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsDefaultVpcResourceType = "aws_default_vpc" - -func initAwsDefaultVpcMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsDefaultVpcResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_dynamodb_table.go b/pkg/resource/aws/aws_dynamodb_table.go index 74bdefa6..d6d24a5f 100644 --- a/pkg/resource/aws/aws_dynamodb_table.go +++ b/pkg/resource/aws/aws_dynamodb_table.go @@ -1,21 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsDynamodbTableResourceType = "aws_dynamodb_table" - func initAwsDynamodbTableMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsDynamodbTableResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "table_name": res.ResourceId(), - } - }) - resourceSchemaRepository.SetNormalizeFunc(AwsDynamodbTableResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsDynamodbTableResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsDynamodbTableResourceType, resource.FlagDeepMode) - } diff --git a/pkg/resource/aws/aws_ebs_encryption_by_default.go b/pkg/resource/aws/aws_ebs_encryption_by_default.go deleted file mode 100644 index 7676628a..00000000 --- a/pkg/resource/aws/aws_ebs_encryption_by_default.go +++ /dev/null @@ -1,9 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsEbsEncryptionByDefaultResourceType = "aws_ebs_encryption_by_default" - -func initAwsEbsEncryptionByDefaultMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsEbsEncryptionByDefaultResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_ebs_snapshot.go b/pkg/resource/aws/aws_ebs_snapshot.go index c9d11e13..9bce422a 100644 --- a/pkg/resource/aws/aws_ebs_snapshot.go +++ b/pkg/resource/aws/aws_ebs_snapshot.go @@ -1,15 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsEbsSnapshotResourceType = "aws_ebs_snapshot" - func initAwsEbsSnapshotMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsEbsSnapshotResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsEbsSnapshotResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsEbsSnapshotResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_ebs_volume.go b/pkg/resource/aws/aws_ebs_volume.go index 0a8a299d..d69f59a0 100644 --- a/pkg/resource/aws/aws_ebs_volume.go +++ b/pkg/resource/aws/aws_ebs_volume.go @@ -1,18 +1,16 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsEbsVolumeResourceType = "aws_ebs_volume" - func initAwsEbsVolumeMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsEbsVolumeResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsEbsVolumeResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"arn"}) val.SafeDelete([]string{"outpost_arn"}) val.SafeDelete([]string{"snapshot_id"}) val.DeleteIfDefault("throughput") }) - resourceSchemaRepository.SetFlags(AwsEbsVolumeResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_ecr_repository.go b/pkg/resource/aws/aws_ecr_repository.go index b9723d14..431c3d89 100644 --- a/pkg/resource/aws/aws_ecr_repository.go +++ b/pkg/resource/aws/aws_ecr_repository.go @@ -1,13 +1,13 @@ package aws -import "github.com/snyk/driftctl/pkg/resource" - -const AwsEcrRepositoryResourceType = "aws_ecr_repository" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) func initAwsEcrRepositoryMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsEcrRepositoryResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsEcrRepositoryResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsEcrRepositoryResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_ecr_repository_test.go b/pkg/resource/aws/aws_ecr_repository_test.go index d2afd53b..c8082ba8 100644 --- a/pkg/resource/aws/aws_ecr_repository_test.go +++ b/pkg/resource/aws/aws_ecr_repository_test.go @@ -8,8 +8,8 @@ import ( "github.com/aws/aws-sdk-go/service/ecr" "github.com/r3labs/diff/v2" + awsresources "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/analyser" - awsresources "github.com/snyk/driftctl/pkg/resource/aws" "github.com/snyk/driftctl/test" "github.com/snyk/driftctl/test/acceptance" "github.com/snyk/driftctl/test/acceptance/awsutils" diff --git a/pkg/resource/aws/aws_eip.go b/pkg/resource/aws/aws_eip.go index c4f95998..d25e820e 100644 --- a/pkg/resource/aws/aws_eip.go +++ b/pkg/resource/aws/aws_eip.go @@ -1,15 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsEipResourceType = "aws_eip" - func initAwsEipMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsEipResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsEipResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsEipResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_eip_association.go b/pkg/resource/aws/aws_eip_association.go deleted file mode 100644 index 2ad67570..00000000 --- a/pkg/resource/aws/aws_eip_association.go +++ /dev/null @@ -1,9 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsEipAssociationResourceType = "aws_eip_association" - -func initAwsEipAssociationMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsEipAssociationResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_iam_access_key.go b/pkg/resource/aws/aws_iam_access_key.go index cce99bc9..dc506d8b 100644 --- a/pkg/resource/aws/aws_iam_access_key.go +++ b/pkg/resource/aws/aws_iam_access_key.go @@ -1,20 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsIamAccessKeyResourceType = "aws_iam_access_key" - func initAwsIAMAccessKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsIamAccessKeyResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "user": *res.Attributes().GetString("user"), - } - }) - - resourceSchemaRepository.SetNormalizeFunc(AwsIamAccessKeyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsIamAccessKeyResourceType, func(res *resource.Resource) { val := res.Attrs // As we can't read secrets from aws API once access_key created we need to set // fields retrieved from state to nil to avoid drift @@ -26,13 +19,4 @@ func initAwsIAMAccessKeyMetaData(resourceSchemaRepository resource.SchemaReposit val.SafeDelete([]string{"key_fingerprint"}) val.SafeDelete([]string{"pgp_key"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsIamAccessKeyResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if user := val.GetString("user"); user != nil && *user != "" { - attrs["User"] = *user - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsIamAccessKeyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_iam_policy.go b/pkg/resource/aws/aws_iam_policy.go index 7f5492e4..e6adfbe6 100644 --- a/pkg/resource/aws/aws_iam_policy.go +++ b/pkg/resource/aws/aws_iam_policy.go @@ -1,19 +1,13 @@ package aws import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" ) -const AwsIamPolicyResourceType = "aws_iam_policy" - func initAwsIAMPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.UpdateSchema(AwsIamPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsIamPolicyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsIamPolicyResourceType, func(res *resource.Resource) { val := res.Attrs jsonString, err := helpers.NormalizeJsonString((*val)["policy"]) if err == nil { @@ -22,5 +16,4 @@ func initAwsIAMPolicyMetaData(resourceSchemaRepository resource.SchemaRepository val.SafeDelete([]string{"name_prefix"}) }) - resourceSchemaRepository.SetFlags(AwsIamPolicyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_iam_policy_attachment.go b/pkg/resource/aws/aws_iam_policy_attachment.go index e50feb30..dcfca92d 100644 --- a/pkg/resource/aws/aws_iam_policy_attachment.go +++ b/pkg/resource/aws/aws_iam_policy_attachment.go @@ -1,15 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsIamPolicyAttachmentResourceType = "aws_iam_policy_attachment" - func initAwsIAMPolicyAttachmentMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsIamPolicyAttachmentResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsIamPolicyAttachmentResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"name"}) }) - resourceSchemaRepository.SetFlags(AwsIamPolicyAttachmentResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_iam_role.go b/pkg/resource/aws/aws_iam_role.go index ebd09961..abc9e987 100644 --- a/pkg/resource/aws/aws_iam_role.go +++ b/pkg/resource/aws/aws_iam_role.go @@ -1,20 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsIamRoleResourceType = "aws_iam_role" - func initAwsIAMRoleMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.UpdateSchema(AwsIamRoleResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "assume_role_policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsIamRoleResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsIamRoleResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"force_detach_policies"}) }) - resourceSchemaRepository.SetFlags(AwsIamRoleResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_iam_role_policy.go b/pkg/resource/aws/aws_iam_role_policy.go deleted file mode 100644 index 255cdaae..00000000 --- a/pkg/resource/aws/aws_iam_role_policy.go +++ /dev/null @@ -1,16 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AwsIamRolePolicyResourceType = "aws_iam_role_policy" - -func initAwsIAMRolePolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.UpdateSchema(AwsIamRolePolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetFlags(AwsIamRolePolicyResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_iam_role_policy_attachment.go b/pkg/resource/aws/aws_iam_role_policy_attachment.go deleted file mode 100644 index c8b7396a..00000000 --- a/pkg/resource/aws/aws_iam_role_policy_attachment.go +++ /dev/null @@ -1,15 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsIamRolePolicyAttachmentResourceType = "aws_iam_role_policy_attachment" - -func initAwsIamRolePolicyAttachmentMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsIamRolePolicyAttachmentResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "role": *res.Attributes().GetString("role"), - "policy_arn": *res.Attributes().GetString("policy_arn"), - } - }) - resourceSchemaRepository.SetFlags(AwsIamRolePolicyAttachmentResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_iam_user.go b/pkg/resource/aws/aws_iam_user.go index c7a6ce54..12b1cd93 100644 --- a/pkg/resource/aws/aws_iam_user.go +++ b/pkg/resource/aws/aws_iam_user.go @@ -1,13 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsIamUserResourceType = "aws_iam_user" - func initAwsIAMUserMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsIamUserResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsIamUserResourceType, func(res *resource.Resource) { val := res.Attrs permissionsBoundary, exist := val.Get("permissions_boundary") if exist && permissionsBoundary == "" { @@ -15,5 +14,4 @@ func initAwsIAMUserMetaData(resourceSchemaRepository resource.SchemaRepositoryIn } val.SafeDelete([]string{"force_destroy"}) }) - resourceSchemaRepository.SetFlags(AwsIamUserResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_iam_user_policy.go b/pkg/resource/aws/aws_iam_user_policy.go deleted file mode 100644 index 19cfc22d..00000000 --- a/pkg/resource/aws/aws_iam_user_policy.go +++ /dev/null @@ -1,16 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AwsIamUserPolicyResourceType = "aws_iam_user_policy" - -func initAwsIAMUserPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.UpdateSchema(AwsIamUserPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetFlags(AwsIamUserPolicyResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_iam_user_policy_attachment.go b/pkg/resource/aws/aws_iam_user_policy_attachment.go deleted file mode 100644 index 5f1aa805..00000000 --- a/pkg/resource/aws/aws_iam_user_policy_attachment.go +++ /dev/null @@ -1,15 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsIamUserPolicyAttachmentResourceType = "aws_iam_user_policy_attachment" - -func initAwsIamUserPolicyAttachmentMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsIamUserPolicyAttachmentResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "user": *res.Attributes().GetString("user"), - "policy_arn": *res.Attributes().GetString("policy_arn"), - } - }) - resourceSchemaRepository.SetFlags(AwsIamUserPolicyAttachmentResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_instance.go b/pkg/resource/aws/aws_instance.go index 4d746092..41760ae5 100644 --- a/pkg/resource/aws/aws_instance.go +++ b/pkg/resource/aws/aws_instance.go @@ -2,13 +2,12 @@ package aws import ( "github.com/hashicorp/go-version" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsInstanceResourceType = "aws_instance" - func initAwsInstanceMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsInstanceResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsInstanceResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) @@ -16,15 +15,4 @@ func initAwsInstanceMetaData(resourceSchemaRepository resource.SchemaRepositoryI val.SafeDelete([]string{"instance_initiated_shutdown_behavior"}) } }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsInstanceResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if tags := val.GetMap("tags"); tags != nil { - if name, ok := tags["Name"]; ok { - attrs["Name"] = name.(string) - } - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsInstanceResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_instance_test.go b/pkg/resource/aws/aws_instance_test.go index 6f658f02..f864c075 100644 --- a/pkg/resource/aws/aws_instance_test.go +++ b/pkg/resource/aws/aws_instance_test.go @@ -9,7 +9,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/r3labs/diff/v2" - awsresources "github.com/snyk/driftctl/pkg/resource/aws" + awsresources "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/test/acceptance" "github.com/snyk/driftctl/test/acceptance/awsutils" ) diff --git a/pkg/resource/aws/aws_internet_gateway.go b/pkg/resource/aws/aws_internet_gateway.go deleted file mode 100644 index 8a258909..00000000 --- a/pkg/resource/aws/aws_internet_gateway.go +++ /dev/null @@ -1,11 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AwsInternetGatewayResourceType = "aws_internet_gateway" - -func initAwsInternetGatewayMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsInternetGatewayResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_key_pair.go b/pkg/resource/aws/aws_key_pair.go index b780a42b..0640764c 100644 --- a/pkg/resource/aws/aws_key_pair.go +++ b/pkg/resource/aws/aws_key_pair.go @@ -1,16 +1,14 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsKeyPairResourceType = "aws_key_pair" - func initAwsKeyPairMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsKeyPairResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsKeyPairResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"key_name_prefix"}) val.SafeDelete([]string{"public_key"}) }) - resourceSchemaRepository.SetFlags(AwsKeyPairResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_kms_alias.go b/pkg/resource/aws/aws_kms_alias.go index 63de3103..4b5cd417 100644 --- a/pkg/resource/aws/aws_kms_alias.go +++ b/pkg/resource/aws/aws_kms_alias.go @@ -1,16 +1,14 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsKmsAliasResourceType = "aws_kms_alias" - func initAwsKmsAliasMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsKmsAliasResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsKmsAliasResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"name"}) val.SafeDelete([]string{"name_prefix"}) }) - resourceSchemaRepository.SetFlags(AwsKmsAliasResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_kms_key.go b/pkg/resource/aws/aws_kms_key.go index c7316da7..ad242564 100644 --- a/pkg/resource/aws/aws_kms_key.go +++ b/pkg/resource/aws/aws_kms_key.go @@ -1,19 +1,13 @@ package aws import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" ) -const AwsKmsKeyResourceType = "aws_kms_key" - func initAwsKmsKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.UpdateSchema(AwsKmsKeyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsKmsKeyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsKmsKeyResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"deletion_window_in_days"}) jsonString, err := helpers.NormalizeJsonString((*val)["policy"]) @@ -22,5 +16,4 @@ func initAwsKmsKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInt } _ = val.SafeSet([]string{"policy"}, jsonString) }) - resourceSchemaRepository.SetFlags(AwsKmsKeyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_lambda_event_source_mapping.go b/pkg/resource/aws/aws_lambda_event_source_mapping.go index 4cf4c50c..6723a2d8 100644 --- a/pkg/resource/aws/aws_lambda_event_source_mapping.go +++ b/pkg/resource/aws/aws_lambda_event_source_mapping.go @@ -1,13 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsLambdaEventSourceMappingResourceType = "aws_lambda_event_source_mapping" - func initAwsLambdaEventSourceMappingMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsLambdaEventSourceMappingResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsLambdaEventSourceMappingResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"state_transition_reason"}) val.SafeDelete([]string{"state"}) @@ -16,16 +15,4 @@ func initAwsLambdaEventSourceMappingMetaData(resourceSchemaRepository resource.S val.SafeDelete([]string{"last_processing_result"}) val.SafeDelete([]string{"last_modified"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsLambdaEventSourceMappingResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - source := val.GetString("event_source_arn") - dest := val.GetString("function_name") - if source != nil && *source != "" && dest != nil && *dest != "" { - attrs["Source"] = *source - attrs["Dest"] = *dest - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsLambdaEventSourceMappingResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_lambda_function.go b/pkg/resource/aws/aws_lambda_function.go index 84b26c38..2e825e8e 100644 --- a/pkg/resource/aws/aws_lambda_function.go +++ b/pkg/resource/aws/aws_lambda_function.go @@ -1,20 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsLambdaFunctionResourceType = "aws_lambda_function" - func initAwsLambdaFunctionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsLambdaFunctionResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "function_name": res.ResourceId(), - } - }) - - resourceSchemaRepository.SetNormalizeFunc(AwsLambdaFunctionResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsLambdaFunctionResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) val.SafeDelete([]string{"publish"}) @@ -27,5 +19,4 @@ func initAwsLambdaFunctionMetaData(resourceSchemaRepository resource.SchemaRepos val.DeleteIfDefault("signing_profile_version_arn") val.SafeDelete([]string{"source_code_size"}) }) - resourceSchemaRepository.SetFlags(AwsLambdaFunctionResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_launch_template.go b/pkg/resource/aws/aws_launch_template.go deleted file mode 100644 index 83189500..00000000 --- a/pkg/resource/aws/aws_launch_template.go +++ /dev/null @@ -1,9 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsLaunchTemplateResourceType = "aws_launch_template" - -func initAwsLaunchTemplateMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsLaunchTemplateResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_lb.go b/pkg/resource/aws/aws_lb.go deleted file mode 100644 index 33c350ab..00000000 --- a/pkg/resource/aws/aws_lb.go +++ /dev/null @@ -1,13 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsLoadBalancerResourceType = "aws_lb" - -func initAwsLoadBalancerMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsLoadBalancerResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "Name": *res.Attributes().GetString("name"), - } - }) -} diff --git a/pkg/resource/aws/aws_nat_gateway.go b/pkg/resource/aws/aws_nat_gateway.go deleted file mode 100644 index c8eddae2..00000000 --- a/pkg/resource/aws/aws_nat_gateway.go +++ /dev/null @@ -1,9 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsNatGatewayResourceType = "aws_nat_gateway" - -func initNatGatewayMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsNatGatewayResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_network_acl.go b/pkg/resource/aws/aws_network_acl.go deleted file mode 100644 index 1897a09f..00000000 --- a/pkg/resource/aws/aws_network_acl.go +++ /dev/null @@ -1,11 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AwsNetworkACLResourceType = "aws_network_acl" - -func initAwsNetworkACLMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsNetworkACLResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_network_acl_rule.go b/pkg/resource/aws/aws_network_acl_rule.go index 1a8d51f0..18b7dd54 100644 --- a/pkg/resource/aws/aws_network_acl_rule.go +++ b/pkg/resource/aws/aws_network_acl_rule.go @@ -1,16 +1,12 @@ package aws import ( - "bytes" - "fmt" "strconv" - "github.com/hashicorp/terraform/helper/hashcode" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsNetworkACLRuleResourceType = "aws_network_acl_rule" - var protocolsNumbers = map[string]int{ // defined at https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml "all": -1, @@ -161,52 +157,7 @@ var protocolsNumbers = map[string]int{ } func initAwsNetworkACLRuleMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsNetworkACLRuleResourceType, resource.FlagDeepMode) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsNetworkACLRuleResourceType, func(res *resource.Resource) map[string]string { - - ruleNumber := strconv.FormatInt(int64(*res.Attrs.GetFloat64("rule_number")), 10) - if ruleNumber == "32767" { - ruleNumber = "*" - } - - attrs := map[string]string{ - "Network": *res.Attrs.GetString("network_acl_id"), - "Egress": strconv.FormatBool(*res.Attrs.GetBool("egress")), - "Rule number": ruleNumber, - } - - if proto := res.Attrs.GetString("protocol"); proto != nil { - if *proto == "-1" { - *proto = "All" - } - attrs["Protocol"] = *proto - } - - if res.Attrs.GetFloat64("from_port") != nil && res.Attrs.GetFloat64("to_port") != nil { - attrs["Port range"] = fmt.Sprintf("%d - %d", - int64(*res.Attrs.GetFloat64("from_port")), - int64(*res.Attrs.GetFloat64("to_port")), - ) - } - - if cidr := res.Attrs.GetString("cidr_block"); cidr != nil && *cidr != "" { - attrs["CIDR"] = *cidr - } - - if cidr := res.Attrs.GetString("ipv6_cidr_block"); cidr != nil && *cidr != "" { - attrs["CIDR"] = *cidr - } - - return attrs - }) - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsNetworkACLRuleResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "network_acl_id": *res.Attrs.GetString("network_acl_id"), - "rule_number": strconv.FormatInt(int64(*res.Attrs.GetFloat64("rule_number")), 10), - "egress": strconv.FormatBool(*res.Attrs.GetBool("egress")), - } - }) - resourceSchemaRepository.SetNormalizeFunc(AwsNetworkACLRuleResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsNetworkACLRuleResourceType, func(res *resource.Resource) { res.Attrs.DeleteIfDefault("icmp_code") res.Attrs.DeleteIfDefault("icmp_type") @@ -229,7 +180,7 @@ func initAwsNetworkACLRuleMetaData(resourceSchemaRepository resource.SchemaRepos // While reading remote we always got protocol as a number. // We cannot predict how the user decided to write the protocol on IaC side. // This workaround is mandatory to harmonize resources ID - res.Id = CreateNetworkACLRuleID( + res.Id = aws.CreateNetworkACLRuleID( *res.Attrs.GetString("network_acl_id"), int(*res.Attrs.GetFloat64("rule_number")), *res.Attrs.GetBool("egress"), @@ -241,12 +192,3 @@ func initAwsNetworkACLRuleMetaData(resourceSchemaRepository resource.SchemaRepos res.Attrs.DeleteIfDefault("ipv6_cidr_block") }) } - -func CreateNetworkACLRuleID(networkAclId string, ruleNumber int, egress bool, protocol string) string { - var buf bytes.Buffer - buf.WriteString(fmt.Sprintf("%s-", networkAclId)) - buf.WriteString(fmt.Sprintf("%d-", ruleNumber)) - buf.WriteString(fmt.Sprintf("%t-", egress)) - buf.WriteString(fmt.Sprintf("%s-", protocol)) - return fmt.Sprintf("nacl-%d", hashcode.String(buf.String())) -} diff --git a/pkg/resource/aws/aws_rds_cluster.go b/pkg/resource/aws/aws_rds_cluster.go index bf9eb1fb..6ade66f3 100644 --- a/pkg/resource/aws/aws_rds_cluster.go +++ b/pkg/resource/aws/aws_rds_cluster.go @@ -1,19 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsRDSClusterResourceType = "aws_rds_cluster" - func initAwsRDSClusterMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsRDSClusterResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "cluster_identifier": *res.Attributes().GetString("cluster_identifier"), - "database_name": *res.Attributes().GetString("database_name"), - } - }) - resourceSchemaRepository.SetNormalizeFunc(AwsRDSClusterResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsRDSClusterResourceType, func(res *resource.Resource) { val := res.Attributes() val.SafeDelete([]string{"timeouts"}) val.SafeDelete([]string{"master_password"}) @@ -24,5 +17,4 @@ func initAwsRDSClusterMetaData(resourceSchemaRepository resource.SchemaRepositor val.SafeDelete([]string{"final_snapshot_identifier"}) val.SafeDelete([]string{"source_region"}) }) - resourceSchemaRepository.SetFlags(AwsRDSClusterResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_route.go b/pkg/resource/aws/aws_route.go index 2fa9210f..61474b51 100644 --- a/pkg/resource/aws/aws_route.go +++ b/pkg/resource/aws/aws_route.go @@ -1,31 +1,12 @@ package aws import ( - "fmt" - - "github.com/hashicorp/terraform/helper/hashcode" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsRouteResourceType = "aws_route" - func initAwsRouteMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsRouteResourceType, func(res *resource.Resource) map[string]string { - attributes := map[string]string{ - "route_table_id": *res.Attributes().GetString("route_table_id"), - } - if ipv4 := res.Attributes().GetString("destination_cidr_block"); ipv4 != nil && *ipv4 != "" { - attributes["destination_cidr_block"] = *ipv4 - } - if ipv6 := res.Attributes().GetString("destination_ipv6_cidr_block"); ipv6 != nil && *ipv6 != "" { - attributes["destination_ipv6_cidr_block"] = *ipv6 - } - if prefixes := res.Attributes().GetString("destination_prefix_list_id"); prefixes != nil && *prefixes != "" { - attributes["destination_prefix_list_id"] = *prefixes - } - return attributes - }) - resourceSchemaRepository.SetNormalizeFunc(AwsRouteResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsRouteResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) @@ -44,38 +25,4 @@ func initAwsRouteMetaData(resourceSchemaRepository resource.SchemaRepositoryInte val.DeleteIfDefault("instance_owner_id") val.DeleteIfDefault("carrier_gateway_id") }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRouteResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if rtID := val.GetString("route_table_id"); rtID != nil && *rtID != "" { - attrs["Table"] = *rtID - } - if ipv4 := val.GetString("destination_cidr_block"); ipv4 != nil && *ipv4 != "" { - attrs["Destination"] = *ipv4 - } - if ipv6 := val.GetString("destination_ipv6_cidr_block"); ipv6 != nil && *ipv6 != "" { - attrs["Destination"] = *ipv6 - } - if prefix := val.GetString("destination_prefix_list_id"); prefix != nil && *prefix != "" { - attrs["Destination"] = *prefix - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsRouteResourceType, resource.FlagDeepMode) -} - -func CalculateRouteID(tableId, CidrBlock, Ipv6CidrBlock, PrefixListId *string) string { - if CidrBlock != nil && *CidrBlock != "" { - return fmt.Sprintf("r-%s%d", *tableId, hashcode.String(*CidrBlock)) - } - - if Ipv6CidrBlock != nil && *Ipv6CidrBlock != "" { - return fmt.Sprintf("r-%s%d", *tableId, hashcode.String(*Ipv6CidrBlock)) - } - - if PrefixListId != nil && *PrefixListId != "" { - return fmt.Sprintf("r-%s%d", *tableId, hashcode.String(*PrefixListId)) - } - - return "" } diff --git a/pkg/resource/aws/aws_route53_health_check.go b/pkg/resource/aws/aws_route53_health_check.go deleted file mode 100644 index ad5e5eff..00000000 --- a/pkg/resource/aws/aws_route53_health_check.go +++ /dev/null @@ -1,43 +0,0 @@ -package aws - -import ( - "fmt" - - "github.com/snyk/driftctl/pkg/resource" -) - -const AwsRoute53HealthCheckResourceType = "aws_route53_health_check" - -func initAwsRoute53HealthCheckMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRoute53HealthCheckResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if tags := val.GetMap("tags"); tags != nil { - if name, ok := tags["Name"]; ok { - attrs["Name"] = name.(string) - } - } - port := val.GetInt("port") - path := val.GetString("resource_path") - if fqdn := val.GetString("fqdn"); fqdn != nil && *fqdn != "" { - attrs["Fqdn"] = *fqdn - if port != nil { - attrs["Port"] = fmt.Sprintf("%d", *port) - } - if path != nil && *path != "" { - attrs["Path"] = *path - } - } - if address := val.GetString("ip_address"); address != nil && *address != "" { - attrs["IpAddress"] = *address - if port != nil { - attrs["Port"] = fmt.Sprintf("%d", *port) - } - if path != nil && *path != "" { - attrs["Path"] = *path - } - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsRoute53HealthCheckResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_route53_health_check_test.go b/pkg/resource/aws/aws_route53_health_check_test.go index 45a60334..6c5ff093 100644 --- a/pkg/resource/aws/aws_route53_health_check_test.go +++ b/pkg/resource/aws/aws_route53_health_check_test.go @@ -7,8 +7,8 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/r3labs/diff/v2" + awsresources "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/analyser" - awsresources "github.com/snyk/driftctl/pkg/resource/aws" "github.com/snyk/driftctl/test" "github.com/snyk/driftctl/test/acceptance" "github.com/snyk/driftctl/test/acceptance/awsutils" diff --git a/pkg/resource/aws/aws_route53_record.go b/pkg/resource/aws/aws_route53_record.go index 09a2db09..564a2e4e 100644 --- a/pkg/resource/aws/aws_route53_record.go +++ b/pkg/resource/aws/aws_route53_record.go @@ -1,13 +1,12 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsRoute53RecordResourceType = "aws_route53_record" - func initAwsRoute53RecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsRoute53RecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsRoute53RecordResourceType, func(res *resource.Resource) { val := res.Attrs val.DeleteIfDefault("health_check_id") val.DeleteIfDefault("set_identifier") @@ -15,19 +14,4 @@ func initAwsRoute53RecordMetaData(resourceSchemaRepository resource.SchemaReposi val.SafeDelete([]string{"name"}) val.SafeDelete([]string{"allow_overwrite"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRoute53RecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if fqdn := val.GetString("fqdn"); fqdn != nil && *fqdn != "" { - attrs["Fqdn"] = *fqdn - } - if ty := val.GetString("type"); ty != nil && *ty != "" { - attrs["Type"] = *ty - } - if zoneID := val.GetString("zone_id"); zoneID != nil && *zoneID != "" { - attrs["ZoneId"] = *zoneID - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsRoute53RecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_route53_zone.go b/pkg/resource/aws/aws_route53_zone.go index 7a038a28..6166cda9 100644 --- a/pkg/resource/aws/aws_route53_zone.go +++ b/pkg/resource/aws/aws_route53_zone.go @@ -1,23 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsRoute53ZoneResourceType = "aws_route53_zone" - func initAwsRoute53ZoneMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsRoute53ZoneResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsRoute53ZoneResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"force_destroy"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRoute53ZoneResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsRoute53ZoneResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_route_table.go b/pkg/resource/aws/aws_route_table.go deleted file mode 100644 index 95a1af5c..00000000 --- a/pkg/resource/aws/aws_route_table.go +++ /dev/null @@ -1,13 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsRouteTableResourceType = "aws_route_table" - -func initAwsRouteTableMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsRouteTableResourceType, resource.FlagDeepMode) - resourceSchemaRepository.SetNormalizeFunc(AwsRouteTableResourceType, func(res *resource.Resource) { - val := res.Attrs - val.SafeDelete([]string{"timeouts"}) - }) -} diff --git a/pkg/resource/aws/aws_route_table_association.go b/pkg/resource/aws/aws_route_table_association.go deleted file mode 100644 index 6c48ffcc..00000000 --- a/pkg/resource/aws/aws_route_table_association.go +++ /dev/null @@ -1,30 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsRouteTableAssociationResourceType = "aws_route_table_association" - -func initAwsRouteTableAssociationMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsRouteTableAssociationResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "route_table_id": *res.Attributes().GetString("route_table_id"), - } - }) - - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsRouteTableAssociationResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if rtID := val.GetString("route_table_id"); rtID != nil && *rtID != "" { - attrs["Table"] = *rtID - } - if gtwID := val.GetString("gateway_id"); gtwID != nil && *gtwID != "" { - attrs["Gateway"] = *gtwID - } - if subnetID := val.GetString("subnet_id"); subnetID != nil && *subnetID != "" { - attrs["Subnet"] = *subnetID - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsRouteTableAssociationResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_s3_bucket.go b/pkg/resource/aws/aws_s3_bucket.go index 17d94807..bfe58785 100644 --- a/pkg/resource/aws/aws_s3_bucket.go +++ b/pkg/resource/aws/aws_s3_bucket.go @@ -1,26 +1,14 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsS3BucketResourceType = "aws_s3_bucket" - func initAwsS3BucketMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "alias": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.UpdateSchema(AwsS3BucketResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsS3BucketResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsS3BucketResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"force_destroy"}) val.SafeDelete([]string{"bucket_prefix"}) }) - resourceSchemaRepository.SetFlags(AwsS3BucketResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_s3_bucket_analytics_configuration.go b/pkg/resource/aws/aws_s3_bucket_analytics_configuration.go deleted file mode 100644 index f0660dca..00000000 --- a/pkg/resource/aws/aws_s3_bucket_analytics_configuration.go +++ /dev/null @@ -1,14 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsS3BucketAnalyticsConfigurationResourceType = "aws_s3_bucket_analytics_configuration" - -func initAwsS3BucketAnalyticsConfigurationMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketAnalyticsConfigurationResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "alias": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.SetFlags(AwsS3BucketAnalyticsConfigurationResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_s3_bucket_inventory.go b/pkg/resource/aws/aws_s3_bucket_inventory.go deleted file mode 100644 index 79e78fd2..00000000 --- a/pkg/resource/aws/aws_s3_bucket_inventory.go +++ /dev/null @@ -1,14 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsS3BucketInventoryResourceType = "aws_s3_bucket_inventory" - -func initAwsS3BucketInventoryMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketInventoryResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "alias": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.SetFlags(AwsS3BucketInventoryResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_s3_bucket_metric.go b/pkg/resource/aws/aws_s3_bucket_metric.go deleted file mode 100644 index 2c6cd0af..00000000 --- a/pkg/resource/aws/aws_s3_bucket_metric.go +++ /dev/null @@ -1,14 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsS3BucketMetricResourceType = "aws_s3_bucket_metric" - -func initAwsS3BucketMetricMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketMetricResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "alias": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.SetFlags(AwsS3BucketMetricResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_s3_bucket_notification.go b/pkg/resource/aws/aws_s3_bucket_notification.go deleted file mode 100644 index 37e3cb29..00000000 --- a/pkg/resource/aws/aws_s3_bucket_notification.go +++ /dev/null @@ -1,14 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsS3BucketNotificationResourceType = "aws_s3_bucket_notification" - -func initAwsS3BucketNotificationMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketNotificationResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "alias": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.SetFlags(AwsS3BucketNotificationResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_s3_bucket_policy.go b/pkg/resource/aws/aws_s3_bucket_policy.go index 7cbf2640..de9fc5f7 100644 --- a/pkg/resource/aws/aws_s3_bucket_policy.go +++ b/pkg/resource/aws/aws_s3_bucket_policy.go @@ -1,24 +1,13 @@ package aws import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" ) -const AwsS3BucketPolicyResourceType = "aws_s3_bucket_policy" - func initAwsS3BucketPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsS3BucketPolicyResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "alias": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.UpdateSchema(AwsS3BucketPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsS3BucketPolicyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsS3BucketPolicyResourceType, func(res *resource.Resource) { val := res.Attrs jsonString, err := helpers.NormalizeJsonString((*val)["policy"]) if err != nil { @@ -26,5 +15,4 @@ func initAwsS3BucketPolicyMetaData(resourceSchemaRepository resource.SchemaRepos } _ = val.SafeSet([]string{"policy"}, jsonString) }) - resourceSchemaRepository.SetFlags(AwsS3BucketPolicyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_security_group.go b/pkg/resource/aws/aws_security_group.go index 93ec512e..7aba0bbb 100644 --- a/pkg/resource/aws/aws_security_group.go +++ b/pkg/resource/aws/aws_security_group.go @@ -1,18 +1,18 @@ package aws -import "github.com/snyk/driftctl/pkg/resource" - -const AwsSecurityGroupResourceType = "aws_security_group" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) func initAwsSecurityGroupMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsSecurityGroupResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSecurityGroupResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"revoke_rules_on_delete"}) val.SafeDelete([]string{"timeouts"}) - //TODO We need to find a way to warn users that some rules in their states could be unmanaged + // TODO We need to find a way to warn users that some rules in their states could be unmanaged val.SafeDelete([]string{"ingress"}) val.SafeDelete([]string{"egress"}) }) - resourceSchemaRepository.SetFlags(AwsSecurityGroupResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_security_group_rule.go b/pkg/resource/aws/aws_security_group_rule.go index c782abf1..80546f01 100644 --- a/pkg/resource/aws/aws_security_group_rule.go +++ b/pkg/resource/aws/aws_security_group_rule.go @@ -1,96 +1,12 @@ package aws import ( - "bytes" - "fmt" - - "github.com/hashicorp/terraform/flatmap" - "github.com/hashicorp/terraform/helper/hashcode" - "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsSecurityGroupRuleResourceType = "aws_security_group_rule" - -func CreateSecurityGroupRuleIdHash(attrs *resource.Attributes) string { - var buf bytes.Buffer - buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("security_group_id"))) - if attrs.GetInt("from_port") != nil && *attrs.GetInt("from_port") > 0 { - buf.WriteString(fmt.Sprintf("%d-", *attrs.GetInt("from_port"))) - } - if attrs.GetInt("to_port") != nil && *attrs.GetInt("to_port") > 0 { - buf.WriteString(fmt.Sprintf("%d-", *attrs.GetInt("to_port"))) - } - buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("protocol"))) - buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("type"))) - - if attrs.GetSlice("cidr_blocks") != nil { - for _, v := range attrs.GetSlice("cidr_blocks") { - buf.WriteString(fmt.Sprintf("%s-", v)) - } - } - - if attrs.GetSlice("ipv6_cidr_blocks") != nil { - for _, v := range attrs.GetSlice("ipv6_cidr_blocks") { - buf.WriteString(fmt.Sprintf("%s-", v)) - } - } - - if attrs.GetSlice("prefix_list_ids") != nil { - for _, v := range attrs.GetSlice("prefix_list_ids") { - buf.WriteString(fmt.Sprintf("%s-", v)) - } - } - - if (attrs.GetBool("self") != nil && *attrs.GetBool("self")) || - (attrs.GetString("source_security_group_id") != nil && *attrs.GetString("source_security_group_id") != "") { - if attrs.GetBool("self") != nil && *attrs.GetBool("self") { - buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("security_group_id"))) - } else { - buf.WriteString(fmt.Sprintf("%s-", *attrs.GetString("source_security_group_id"))) - } - buf.WriteString("-") - } - - return fmt.Sprintf("sgrule-%d", hashcode.String(buf.String())) -} - func initAwsSecurityGroupRuleMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSecurityGroupRuleResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]interface{}) - if v, ok := res.Attributes().Get("type"); ok { - attrs["type"] = v - } - if v, ok := res.Attributes().Get("protocol"); ok { - attrs["protocol"] = v - } - if v := res.Attributes().GetInt("from_port"); v != nil { - attrs["from_port"] = *v - } - if v := res.Attributes().GetInt("to_port"); v != nil { - attrs["to_port"] = *v - } - if v, ok := res.Attributes().Get("security_group_id"); ok { - attrs["security_group_id"] = v - } - if v, ok := res.Attributes().Get("self"); ok { - attrs["self"] = v - } - if v, ok := res.Attributes().Get("cidr_blocks"); ok { - attrs["cidr_blocks"] = v - } - if v, ok := res.Attributes().Get("ipv6_cidr_blocks"); ok { - attrs["ipv6_cidr_blocks"] = v - } - if v, ok := res.Attributes().Get("prefix_list_ids"); ok { - attrs["prefix_list_ids"] = v - } - if v, ok := res.Attributes().Get("source_security_group_id"); ok { - attrs["source_security_group_id"] = v - } - return flatmap.Flatten(attrs) - }) - resourceSchemaRepository.SetNormalizeFunc(AwsSecurityGroupRuleResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSecurityGroupRuleResourceType, func(res *resource.Resource) { val := res.Attrs val.DeleteIfDefault("security_group_id") val.DeleteIfDefault("source_security_group_id") @@ -108,58 +24,8 @@ func initAwsSecurityGroupRuleMetaData(resourceSchemaRepository resource.SchemaRe val.SafeDelete([]string{"from_port"}) val.SafeDelete([]string{"to_port"}) - id := CreateSecurityGroupRuleIdHash(val) + id := aws.CreateSecurityGroupRuleIdHash(val) _ = val.SafeSet([]string{"id"}, id) res.Id = id }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsSecurityGroupRuleResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if sgID := val.GetString("security_group_id"); sgID != nil && *sgID != "" { - attrs["SecurityGroup"] = *sgID - } - if protocol := val.GetString("protocol"); protocol != nil && *protocol != "" { - if *protocol == "-1" { - *protocol = "All" - } - attrs["Protocol"] = *protocol - } - fromPort := val.GetInt("from_port") - toPort := val.GetInt("to_port") - if fromPort != nil && toPort != nil { - portRange := "All" - if *fromPort != 0 && *fromPort == *toPort { - portRange = fmt.Sprintf("%d", *fromPort) - } - if *fromPort != 0 && *toPort != 0 && *fromPort != *toPort { - portRange = fmt.Sprintf("%d-%d", *fromPort, *toPort) - } - attrs["Ports"] = portRange - } - ty := val.GetString("type") - if ty != nil && *ty != "" { - attrs["Type"] = *ty - var sourceOrDestination string - switch *ty { - case "egress": - sourceOrDestination = "Destination" - case "ingress": - sourceOrDestination = "Source" - } - if ipv4 := val.GetSlice("cidr_blocks"); len(ipv4) > 0 { - attrs[sourceOrDestination] = helpers.Join(ipv4, ", ") - } - if ipv6 := val.GetSlice("ipv6_cidr_blocks"); len(ipv6) > 0 { - attrs[sourceOrDestination] = helpers.Join(ipv6, ", ") - } - if prefixList := val.GetSlice("prefix_list_ids"); len(prefixList) > 0 { - attrs[sourceOrDestination] = helpers.Join(prefixList, ", ") - } - if sourceSgID := val.GetString("source_security_group_id"); sourceSgID != nil && *sourceSgID != "" { - attrs[sourceOrDestination] = *sourceSgID - } - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsSecurityGroupRuleResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_sns_topic.go b/pkg/resource/aws/aws_sns_topic.go index da62d48b..494069de 100644 --- a/pkg/resource/aws/aws_sns_topic.go +++ b/pkg/resource/aws/aws_sns_topic.go @@ -1,24 +1,12 @@ package aws -import "github.com/snyk/driftctl/pkg/resource" - -const AwsSnsTopicResourceType = "aws_sns_topic" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" +) func initSnsTopicMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSnsTopicResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "topic_arn": res.ResourceId(), - } - }) - resourceSchemaRepository.UpdateSchema(AwsSnsTopicResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "delivery_policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsSnsTopicResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSnsTopicResourceType, func(res *resource.Resource) { val := res.Attrs val.DeleteIfDefault("sqs_success_feedback_sample_rate") val.DeleteIfDefault("lambda_success_feedback_sample_rate") @@ -29,16 +17,4 @@ func initSnsTopicMetaData(resourceSchemaRepository resource.SchemaRepositoryInte val.SafeDelete([]string{"name_prefix"}) val.SafeDelete([]string{"owner"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AwsSnsTopicResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - if displayName := val.GetString("display_name"); displayName != nil && *displayName != "" { - attrs["DisplayName"] = *displayName - } - } - return attrs - }) - resourceSchemaRepository.SetFlags(AwsSnsTopicResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_sns_topic_policy.go b/pkg/resource/aws/aws_sns_topic_policy.go index a3d3f667..a625cbc7 100644 --- a/pkg/resource/aws/aws_sns_topic_policy.go +++ b/pkg/resource/aws/aws_sns_topic_policy.go @@ -1,26 +1,13 @@ package aws import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" ) -const AwsSnsTopicPolicyResourceType = "aws_sns_topic_policy" - func initSnsTopicPolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSnsTopicPolicyResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "topic_arn": res.ResourceId(), - } - }) - - resourceSchemaRepository.UpdateSchema(AwsSnsTopicPolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - - resourceSchemaRepository.SetNormalizeFunc(AwsSnsTopicPolicyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSnsTopicPolicyResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"owner"}) jsonString, err := helpers.NormalizeJsonString((*val)["policy"]) @@ -29,5 +16,4 @@ func initSnsTopicPolicyMetaData(resourceSchemaRepository resource.SchemaReposito } _ = val.SafeSet([]string{"policy"}, jsonString) }) - resourceSchemaRepository.SetFlags(AwsSnsTopicPolicyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_sns_topic_policy_test.go b/pkg/resource/aws/aws_sns_topic_policy_test.go index 0f96334d..44110606 100644 --- a/pkg/resource/aws/aws_sns_topic_policy_test.go +++ b/pkg/resource/aws/aws_sns_topic_policy_test.go @@ -1,12 +1,12 @@ package aws_test import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" "testing" "time" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" "github.com/snyk/driftctl/test" "github.com/snyk/driftctl/test/acceptance" "github.com/snyk/driftctl/test/acceptance/awsutils" diff --git a/pkg/resource/aws/aws_sns_topic_subscription.go b/pkg/resource/aws/aws_sns_topic_subscription.go index d89721ed..ed41dacf 100644 --- a/pkg/resource/aws/aws_sns_topic_subscription.go +++ b/pkg/resource/aws/aws_sns_topic_subscription.go @@ -1,29 +1,13 @@ package aws import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" ) -const AwsSnsTopicSubscriptionResourceType = "aws_sns_topic_subscription" - func initSnsTopicSubscriptionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(AwsSnsTopicSubscriptionResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "SubscriptionId": res.ResourceId(), - } - }) - - resourceSchemaRepository.UpdateSchema(AwsSnsTopicSubscriptionResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "delivery_policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - "filter_policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - - resourceSchemaRepository.SetNormalizeFunc(AwsSnsTopicSubscriptionResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSnsTopicSubscriptionResourceType, func(res *resource.Resource) { val := res.Attrs jsonString, err := helpers.NormalizeJsonString((*val)["delivery_policy"]) if err == nil { @@ -42,5 +26,4 @@ func initSnsTopicSubscriptionMetaData(resourceSchemaRepository resource.SchemaRe val.SafeDelete([]string{"confirmation_timeout_in_minutes"}) } }) - resourceSchemaRepository.SetFlags(AwsSnsTopicSubscriptionResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_sns_topic_subscription_test.go b/pkg/resource/aws/aws_sns_topic_subscription_test.go index fcd4acb4..f5f6cecb 100644 --- a/pkg/resource/aws/aws_sns_topic_subscription_test.go +++ b/pkg/resource/aws/aws_sns_topic_subscription_test.go @@ -1,12 +1,12 @@ package aws_test import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" "testing" "time" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/aws/repository" - "github.com/snyk/driftctl/pkg/remote/cache" "github.com/snyk/driftctl/test" "github.com/snyk/driftctl/test/acceptance" "github.com/snyk/driftctl/test/acceptance/awsutils" diff --git a/pkg/resource/aws/aws_sns_topic_test.go b/pkg/resource/aws/aws_sns_topic_test.go index cf45ddd8..e4e4492f 100644 --- a/pkg/resource/aws/aws_sns_topic_test.go +++ b/pkg/resource/aws/aws_sns_topic_test.go @@ -1,22 +1,21 @@ package aws_test import ( + "github.com/snyk/driftctl/enumeration/remote/aws/repository" + "github.com/snyk/driftctl/enumeration/remote/cache" "strings" "testing" "time" "github.com/sirupsen/logrus" - "github.com/snyk/driftctl/pkg/remote/cache" - - "github.com/snyk/driftctl/pkg/remote/aws/repository" "github.com/snyk/driftctl/test" "github.com/aws/aws-sdk-go/service/sns" "github.com/aws/aws-sdk-go/aws" "github.com/r3labs/diff/v2" + awsresources "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/analyser" - awsresources "github.com/snyk/driftctl/pkg/resource/aws" "github.com/snyk/driftctl/test/acceptance" "github.com/snyk/driftctl/test/acceptance/awsutils" ) diff --git a/pkg/resource/aws/aws_sqs_queue.go b/pkg/resource/aws/aws_sqs_queue.go deleted file mode 100644 index 6415c3e7..00000000 --- a/pkg/resource/aws/aws_sqs_queue.go +++ /dev/null @@ -1,11 +0,0 @@ -package aws - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AwsSqsQueueResourceType = "aws_sqs_queue" - -func initSqsQueueMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsSqsQueueResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/aws_sqs_queue_policy.go b/pkg/resource/aws/aws_sqs_queue_policy.go index 263b1555..7c71accd 100644 --- a/pkg/resource/aws/aws_sqs_queue_policy.go +++ b/pkg/resource/aws/aws_sqs_queue_policy.go @@ -1,19 +1,13 @@ package aws import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/helpers" - "github.com/snyk/driftctl/pkg/resource" ) -const AwsSqsQueuePolicyResourceType = "aws_sqs_queue_policy" - func initAwsSQSQueuePolicyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.UpdateSchema(AwsSqsQueuePolicyResourceType, map[string]func(attributeSchema *resource.AttributeSchema){ - "policy": func(attributeSchema *resource.AttributeSchema) { - attributeSchema.JsonString = true - }, - }) - resourceSchemaRepository.SetNormalizeFunc(AwsSqsQueuePolicyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSqsQueuePolicyResourceType, func(res *resource.Resource) { val := res.Attrs jsonString, err := helpers.NormalizeJsonString((*val)["policy"]) if err != nil { @@ -21,5 +15,4 @@ func initAwsSQSQueuePolicyMetaData(resourceSchemaRepository resource.SchemaRepos } _ = val.SafeSet([]string{"policy"}, jsonString) }) - resourceSchemaRepository.SetFlags(AwsSqsQueuePolicyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_sqs_queue_test.go b/pkg/resource/aws/aws_sqs_queue_test.go index 6ace4749..b50a9d47 100644 --- a/pkg/resource/aws/aws_sqs_queue_test.go +++ b/pkg/resource/aws/aws_sqs_queue_test.go @@ -6,8 +6,8 @@ import ( "github.com/sirupsen/logrus" + awsresources "github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/pkg/analyser" - awsresources "github.com/snyk/driftctl/pkg/resource/aws" "github.com/snyk/driftctl/test" "github.com/r3labs/diff/v2" diff --git a/pkg/resource/aws/aws_subnet.go b/pkg/resource/aws/aws_subnet.go index 1de8f2b5..9795a880 100644 --- a/pkg/resource/aws/aws_subnet.go +++ b/pkg/resource/aws/aws_subnet.go @@ -1,15 +1,13 @@ package aws import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/aws" ) -const AwsSubnetResourceType = "aws_subnet" - func initAwsSubnetMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AwsSubnetResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(aws.AwsSubnetResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AwsSubnetResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/aws/aws_vpc.go b/pkg/resource/aws/aws_vpc.go deleted file mode 100644 index 848763b9..00000000 --- a/pkg/resource/aws/aws_vpc.go +++ /dev/null @@ -1,9 +0,0 @@ -package aws - -import "github.com/snyk/driftctl/pkg/resource" - -const AwsVpcResourceType = "aws_vpc" - -func initAwsVpcMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AwsVpcResourceType, resource.FlagDeepMode) -} diff --git a/pkg/resource/aws/metadata_test.go b/pkg/resource/aws/metadata_test.go deleted file mode 100644 index 0f1d1f29..00000000 --- a/pkg/resource/aws/metadata_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package aws - -import ( - "testing" - - "github.com/snyk/driftctl/pkg/resource" - tf "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" -) - -func TestAWS_Metadata_Flags(t *testing.T) { - testcases := map[string][]resource.Flags{ - AwsAmiResourceType: {resource.FlagDeepMode}, - AwsApiGatewayAccountResourceType: {}, - AwsApiGatewayApiKeyResourceType: {}, - AwsApiGatewayAuthorizerResourceType: {}, - AwsApiGatewayBasePathMappingResourceType: {}, - AwsApiGatewayDeploymentResourceType: {}, - AwsApiGatewayDomainNameResourceType: {}, - AwsApiGatewayGatewayResponseResourceType: {}, - AwsApiGatewayIntegrationResourceType: {}, - AwsApiGatewayIntegrationResponseResourceType: {}, - AwsApiGatewayMethodResourceType: {}, - AwsApiGatewayMethodResponseResourceType: {}, - AwsApiGatewayMethodSettingsResourceType: {}, - AwsApiGatewayModelResourceType: {}, - AwsApiGatewayRequestValidatorResourceType: {}, - AwsApiGatewayResourceResourceType: {}, - AwsApiGatewayRestApiResourceType: {}, - AwsApiGatewayRestApiPolicyResourceType: {}, - AwsApiGatewayStageResourceType: {}, - AwsApiGatewayVpcLinkResourceType: {}, - AwsApiGatewayV2ApiResourceType: {}, - AwsApiGatewayV2RouteResourceType: {}, - AwsApiGatewayV2DeploymentResourceType: {}, - AwsApiGatewayV2VpcLinkResourceType: {}, - AwsApiGatewayV2AuthorizerResourceType: {}, - AwsApiGatewayV2RouteResponseResourceType: {}, - AwsApiGatewayV2DomainNameResourceType: {}, - AwsApiGatewayV2ModelResourceType: {}, - AwsApiGatewayV2StageResourceType: {}, - AwsApiGatewayV2MappingResourceType: {}, - AwsApiGatewayV2IntegrationResourceType: {}, - AwsApiGatewayV2IntegrationResponseResourceType: {}, - AwsAppAutoscalingPolicyResourceType: {resource.FlagDeepMode}, - AwsAppAutoscalingScheduledActionResourceType: {}, - AwsAppAutoscalingTargetResourceType: {resource.FlagDeepMode}, - AwsCloudformationStackResourceType: {resource.FlagDeepMode}, - AwsCloudfrontDistributionResourceType: {resource.FlagDeepMode}, - AwsDbInstanceResourceType: {resource.FlagDeepMode}, - AwsDbSubnetGroupResourceType: {resource.FlagDeepMode}, - AwsDefaultNetworkACLResourceType: {resource.FlagDeepMode}, - AwsDefaultRouteTableResourceType: {resource.FlagDeepMode}, - AwsDefaultSecurityGroupResourceType: {resource.FlagDeepMode}, - AwsDefaultSubnetResourceType: {resource.FlagDeepMode}, - AwsDefaultVpcResourceType: {resource.FlagDeepMode}, - AwsDynamodbTableResourceType: {resource.FlagDeepMode}, - AwsEbsEncryptionByDefaultResourceType: {resource.FlagDeepMode}, - AwsEbsSnapshotResourceType: {resource.FlagDeepMode}, - AwsEbsVolumeResourceType: {resource.FlagDeepMode}, - AwsEcrRepositoryResourceType: {resource.FlagDeepMode}, - AwsEipResourceType: {resource.FlagDeepMode}, - AwsEipAssociationResourceType: {resource.FlagDeepMode}, - AwsElastiCacheClusterResourceType: {}, - AwsIamAccessKeyResourceType: {resource.FlagDeepMode}, - AwsIamPolicyResourceType: {resource.FlagDeepMode}, - AwsIamPolicyAttachmentResourceType: {resource.FlagDeepMode}, - AwsIamRoleResourceType: {resource.FlagDeepMode}, - AwsIamRolePolicyResourceType: {resource.FlagDeepMode}, - AwsIamRolePolicyAttachmentResourceType: {resource.FlagDeepMode}, - AwsIamUserResourceType: {resource.FlagDeepMode}, - AwsIamUserPolicyResourceType: {resource.FlagDeepMode}, - AwsIamUserPolicyAttachmentResourceType: {resource.FlagDeepMode}, - AwsIamGroupPolicyResourceType: {}, - AwsIamGroupPolicyAttachmentResourceType: {}, - AwsInstanceResourceType: {resource.FlagDeepMode}, - AwsInternetGatewayResourceType: {resource.FlagDeepMode}, - AwsKeyPairResourceType: {resource.FlagDeepMode}, - AwsKmsAliasResourceType: {resource.FlagDeepMode}, - AwsKmsKeyResourceType: {resource.FlagDeepMode}, - AwsLambdaEventSourceMappingResourceType: {resource.FlagDeepMode}, - AwsLambdaFunctionResourceType: {resource.FlagDeepMode}, - AwsNatGatewayResourceType: {resource.FlagDeepMode}, - AwsNetworkACLResourceType: {resource.FlagDeepMode}, - AwsRDSClusterResourceType: {resource.FlagDeepMode}, - AwsRDSClusterInstanceResourceType: {}, - AwsRouteResourceType: {resource.FlagDeepMode}, - AwsRoute53HealthCheckResourceType: {resource.FlagDeepMode}, - AwsRoute53RecordResourceType: {resource.FlagDeepMode}, - AwsRoute53ZoneResourceType: {resource.FlagDeepMode}, - AwsRouteTableResourceType: {resource.FlagDeepMode}, - AwsRouteTableAssociationResourceType: {resource.FlagDeepMode}, - AwsS3BucketResourceType: {resource.FlagDeepMode}, - AwsS3BucketAnalyticsConfigurationResourceType: {resource.FlagDeepMode}, - AwsS3BucketInventoryResourceType: {resource.FlagDeepMode}, - AwsS3BucketMetricResourceType: {resource.FlagDeepMode}, - AwsS3BucketNotificationResourceType: {resource.FlagDeepMode}, - AwsS3BucketPolicyResourceType: {resource.FlagDeepMode}, - AwsS3BucketPublicAccessBlockResourceType: {}, - AwsSecurityGroupResourceType: {resource.FlagDeepMode}, - AwsSnsTopicResourceType: {resource.FlagDeepMode}, - AwsSnsTopicPolicyResourceType: {resource.FlagDeepMode}, - AwsSnsTopicSubscriptionResourceType: {resource.FlagDeepMode}, - AwsSqsQueueResourceType: {resource.FlagDeepMode}, - AwsSqsQueuePolicyResourceType: {resource.FlagDeepMode}, - AwsSubnetResourceType: {resource.FlagDeepMode}, - AwsVpcResourceType: {resource.FlagDeepMode}, - AwsSecurityGroupRuleResourceType: {resource.FlagDeepMode}, - AwsNetworkACLRuleResourceType: {resource.FlagDeepMode}, - AwsLaunchTemplateResourceType: {resource.FlagDeepMode}, - AwsLaunchConfigurationResourceType: {}, - AwsLoadBalancerResourceType: {}, - AwsApplicationLoadBalancerResourceType: {}, - AwsClassicLoadBalancerResourceType: {}, - AwsLoadBalancerListenerResourceType: {}, - AwsApplicationLoadBalancerListenerResourceType: {}, - AwsIamGroupResourceType: {}, - AwsEcrRepositoryPolicyResourceType: {}, - } - - schemaRepository := testresource.InitFakeSchemaRepository(tf.AWS, "3.19.0") - InitResourcesMetadata(schemaRepository) - - for ty, flags := range testcases { - t.Run(ty, func(tt *testing.T) { - sch, exist := schemaRepository.GetSchema(ty) - assert.True(tt, exist) - - if len(flags) == 0 { - assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") - return - } - - for _, flag := range flags { - assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) - } - }) - } -} diff --git a/pkg/resource/aws/metadatas.go b/pkg/resource/aws/metadatas.go index 63be9bba..5a0cc021 100644 --- a/pkg/resource/aws/metadatas.go +++ b/pkg/resource/aws/metadatas.go @@ -1,6 +1,6 @@ package aws -import "github.com/snyk/driftctl/pkg/resource" +import "github.com/snyk/driftctl/enumeration/resource" func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { initAwsAmiMetaData(resourceSchemaRepository) @@ -9,48 +9,30 @@ func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInt initAwsDbSubnetGroupMetaData(resourceSchemaRepository) initAwsDefaultSecurityGroupMetaData(resourceSchemaRepository) initAwsDefaultSubnetMetaData(resourceSchemaRepository) - initAwsDefaultVpcMetaData(resourceSchemaRepository) - initAwsDefaultRouteTableMetadata(resourceSchemaRepository) initAwsDynamodbTableMetaData(resourceSchemaRepository) initAwsEbsSnapshotMetaData(resourceSchemaRepository) initAwsInstanceMetaData(resourceSchemaRepository) - initAwsInternetGatewayMetaData(resourceSchemaRepository) initAwsEbsVolumeMetaData(resourceSchemaRepository) initAwsEipMetaData(resourceSchemaRepository) - initAwsEipAssociationMetaData(resourceSchemaRepository) initAwsS3BucketMetaData(resourceSchemaRepository) initAwsS3BucketPolicyMetaData(resourceSchemaRepository) - initAwsS3BucketInventoryMetadata(resourceSchemaRepository) - initAwsS3BucketMetricMetadata(resourceSchemaRepository) - initAwsS3BucketNotificationMetadata(resourceSchemaRepository) - initAwsS3BucketAnalyticsConfigurationMetaData(resourceSchemaRepository) initAwsEcrRepositoryMetaData(resourceSchemaRepository) initAwsRouteMetaData(resourceSchemaRepository) - initAwsRouteTableAssociationMetaData(resourceSchemaRepository) initAwsRoute53RecordMetaData(resourceSchemaRepository) initAwsRoute53ZoneMetaData(resourceSchemaRepository) - initAwsRoute53HealthCheckMetaData(resourceSchemaRepository) - initAwsRouteTableMetaData(resourceSchemaRepository) initSnsTopicSubscriptionMetaData(resourceSchemaRepository) initSnsTopicPolicyMetaData(resourceSchemaRepository) initSnsTopicMetaData(resourceSchemaRepository) - initSqsQueueMetaData(resourceSchemaRepository) initAwsIAMAccessKeyMetaData(resourceSchemaRepository) initAwsIAMPolicyMetaData(resourceSchemaRepository) initAwsIAMPolicyAttachmentMetaData(resourceSchemaRepository) initAwsIAMRoleMetaData(resourceSchemaRepository) - initAwsIAMRolePolicyMetaData(resourceSchemaRepository) - initAwsIamRolePolicyAttachmentMetaData(resourceSchemaRepository) - initAwsIamUserPolicyAttachmentMetaData(resourceSchemaRepository) initAwsIAMUserMetaData(resourceSchemaRepository) - initAwsIAMUserPolicyMetaData(resourceSchemaRepository) initAwsKeyPairMetaData(resourceSchemaRepository) initAwsKmsKeyMetaData(resourceSchemaRepository) initAwsKmsAliasMetaData(resourceSchemaRepository) initAwsLambdaFunctionMetaData(resourceSchemaRepository) initAwsLambdaEventSourceMappingMetaData(resourceSchemaRepository) - initNatGatewayMetaData(resourceSchemaRepository) - initAwsNetworkACLMetaData(resourceSchemaRepository) initAwsNetworkACLRuleMetaData(resourceSchemaRepository) initAwsDefaultNetworkACLMetaData(resourceSchemaRepository) initAwsSubnetMetaData(resourceSchemaRepository) @@ -59,12 +41,5 @@ func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInt initAwsSecurityGroupMetaData(resourceSchemaRepository) initAwsRDSClusterMetaData(resourceSchemaRepository) initAwsCloudformationStackMetaData(resourceSchemaRepository) - initAwsVpcMetaData(resourceSchemaRepository) initAwsAppAutoscalingTargetMetaData(resourceSchemaRepository) - initAwsAppAutoscalingPolicyMetaData(resourceSchemaRepository) - initAwsLaunchTemplateMetaData(resourceSchemaRepository) - initAwsApiGatewayV2ModelMetaData(resourceSchemaRepository) - initAwsApiGatewayV2MappingMetaData(resourceSchemaRepository) - initAwsEbsEncryptionByDefaultMetaData(resourceSchemaRepository) - initAwsLoadBalancerMetaData(resourceSchemaRepository) } diff --git a/pkg/resource/aws/testdata/acc/aws_alb/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_alb/.terraform.lock.hcl index 34706398..cb60e412 100644 --- a/pkg/resource/aws/testdata/acc/aws_alb/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_alb/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "4.8.0" hashes = [ "h1:T9Typ5V+dDwecG9USCLbW4oayxN3cxEGsG+OJzzjRgY=", + "h1:W2cPGKmqkPbTc91lu42QeC3RFBqB5TnRnS3IxNME2FM=", "zh:16cbdbc03ad13358d12433e645e2ab5a615e3a3662a74e3c317267c9377713d8", "zh:1d813c5e6c21fe370652495e29f783db4e65037f913ff0d53d28515c36fbb70a", "zh:31ad8282e31d0fac62e96fc2321a68ad4b92ab90f560be5f875d1b01a493e491", diff --git a/pkg/resource/aws/testdata/acc/aws_alb_listener/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_alb_listener/.terraform.lock.hcl index ae023925..c236fad4 100644 --- a/pkg/resource/aws/testdata/acc/aws_alb_listener/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_alb_listener/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "4.9.0" hashes = [ "h1:GtmIOZMkKmr9tMLWouHWiGXmKEL/diOTNar5XfOVLjs=", + "h1:OWIIlbMZl/iQ8qR1U7Co3sGjNHL1HJtgNRnnV1kXNuI=", "zh:084b83aef3335ad4f5e4b8323c6fe43c1ff55e17a7647c6a5cad6af519f72b42", "zh:132e47ce69f14de4523b84b213cedf7173398acda14245b1ffe7747aac50f050", "zh:2068baef7dfce3613f3b4f27314175e971f8db68d9cde9ec30b5659f80c68c6c", diff --git a/pkg/resource/aws/testdata/acc/aws_apigatewayv2_mapping/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_apigatewayv2_mapping/.terraform.lock.hcl index 01b6b450..9e7a3811 100644 --- a/pkg/resource/aws/testdata/acc/aws_apigatewayv2_mapping/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_apigatewayv2_mapping/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.47.0" hashes = [ "h1:gXncRh1KtgLNMeb3/bYq5CvGfy8YTR+n6ds1noc5ggc=", + "h1:oiX6JcoXY6lYIdcYWmEpr7mnS4mkyDV9intCNrcjiBs=", "zh:07bb6bda5b9fdb782dd568a2e85cfe0ab108770e2218f3411e57ed845c58af40", "zh:0926b161a109e75bdc8691e8a32f568b4cd77a55510cf27573261fb5ba382287", "zh:0a91adf25a78ad31d547da513db24f493d27592d3675ed291a7698351c30992d", @@ -23,6 +24,7 @@ provider "registry.terraform.io/hashicorp/aws" { provider "registry.terraform.io/hashicorp/tls" { version = "3.1.0" hashes = [ + "h1:XTU9f6sGMZHOT8r/+LWCz2BZOPH127FBTPjMMEAAu1U=", "h1:fUJX8Zxx38e2kBln+zWr1Tl41X+OuiE++REjrEyiOM4=", "zh:3d46616b41fea215566f4a957b6d3a1aa43f1f75c26776d72a98bdba79439db6", "zh:623a203817a6dafa86f1b4141b645159e07ec418c82fe40acd4d2a27543cbaa2", diff --git a/pkg/resource/aws/testdata/acc/aws_apigatewayv2_model/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_apigatewayv2_model/.terraform.lock.hcl index 642a9881..51efd1d6 100644 --- a/pkg/resource/aws/testdata/acc/aws_apigatewayv2_model/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_apigatewayv2_model/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.47.0" hashes = [ "h1:gXncRh1KtgLNMeb3/bYq5CvGfy8YTR+n6ds1noc5ggc=", + "h1:oiX6JcoXY6lYIdcYWmEpr7mnS4mkyDV9intCNrcjiBs=", "zh:07bb6bda5b9fdb782dd568a2e85cfe0ab108770e2218f3411e57ed845c58af40", "zh:0926b161a109e75bdc8691e8a32f568b4cd77a55510cf27573261fb5ba382287", "zh:0a91adf25a78ad31d547da513db24f493d27592d3675ed291a7698351c30992d", diff --git a/pkg/resource/aws/testdata/acc/aws_apigatewayv2_stage/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_apigatewayv2_stage/.terraform.lock.hcl index 642a9881..51efd1d6 100644 --- a/pkg/resource/aws/testdata/acc/aws_apigatewayv2_stage/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_apigatewayv2_stage/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.47.0" hashes = [ "h1:gXncRh1KtgLNMeb3/bYq5CvGfy8YTR+n6ds1noc5ggc=", + "h1:oiX6JcoXY6lYIdcYWmEpr7mnS4mkyDV9intCNrcjiBs=", "zh:07bb6bda5b9fdb782dd568a2e85cfe0ab108770e2218f3411e57ed845c58af40", "zh:0926b161a109e75bdc8691e8a32f568b4cd77a55510cf27573261fb5ba382287", "zh:0a91adf25a78ad31d547da513db24f493d27592d3675ed291a7698351c30992d", diff --git a/pkg/resource/aws/testdata/acc/aws_appautoscaling_policy/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_appautoscaling_policy/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_appautoscaling_policy/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_appautoscaling_policy/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_appautoscaling_scheduled_action/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_appautoscaling_scheduled_action/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_appautoscaling_scheduled_action/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_appautoscaling_scheduled_action/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_appautoscaling_target/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_appautoscaling_target/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_appautoscaling_target/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_appautoscaling_target/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_dynamodb_table/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_dynamodb_table/.terraform.lock.hcl index b6fbf3c6..8a74c4b9 100644 --- a/pkg/resource/aws/testdata/acc/aws_dynamodb_table/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_dynamodb_table/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.32.0" hashes = [ "h1:l8jJYQ4bPEbNwZUoHYmeR1woajPzJSX5hPCLWuRVFwc=", + "h1:tY5R3hRBB1v8v3MT+ofoCJ/Vz/b/4PXo3xtBKtw7T4A=", "zh:04e4f700c21b1f58e7603638160bd5ad3b85519c35dc75bada3e52b164d06d3e", "zh:09f2338404d4b2d4dcb29781ac59a6955d935745e896d4ee661d83cac8d7c677", "zh:16bdf96d8139268766921d5b891b865f67936190dc302283ba50b94e42510ec5", diff --git a/pkg/resource/aws/testdata/acc/aws_ebs_encryption_by_default/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_ebs_encryption_by_default/.terraform.lock.hcl index 34706398..cb60e412 100644 --- a/pkg/resource/aws/testdata/acc/aws_ebs_encryption_by_default/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_ebs_encryption_by_default/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "4.8.0" hashes = [ "h1:T9Typ5V+dDwecG9USCLbW4oayxN3cxEGsG+OJzzjRgY=", + "h1:W2cPGKmqkPbTc91lu42QeC3RFBqB5TnRnS3IxNME2FM=", "zh:16cbdbc03ad13358d12433e645e2ab5a615e3a3662a74e3c317267c9377713d8", "zh:1d813c5e6c21fe370652495e29f783db4e65037f913ff0d53d28515c36fbb70a", "zh:31ad8282e31d0fac62e96fc2321a68ad4b92ab90f560be5f875d1b01a493e491", diff --git a/pkg/resource/aws/testdata/acc/aws_eip/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_eip/.terraform.lock.hcl index 19dd55fa..6d3f22df 100644 --- a/pkg/resource/aws/testdata/acc/aws_eip/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_eip/.terraform.lock.hcl @@ -3,8 +3,9 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.44.0" - constraints = "~> 3.44.0" + constraints = "3.44.0" hashes = [ + "h1:VOVZWybe1x0E4qyawTwt7jXVBRUplTrzVFHim217DqI=", "h1:hxQ8n9SHHfAIXd/FtfAqxokFYWBedzZf7xqQZWJajUs=", "zh:0680315b29a140e9b7e4f5aeed3f2445abdfab31fc9237f34dcad06de4f410df", "zh:13811322a205fb4a0ee617f0ae51ec94176befdf569235d0c7064db911f0acc7", diff --git a/pkg/resource/aws/testdata/acc/aws_eip_association/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_eip_association/.terraform.lock.hcl index 19dd55fa..6d3f22df 100644 --- a/pkg/resource/aws/testdata/acc/aws_eip_association/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_eip_association/.terraform.lock.hcl @@ -3,8 +3,9 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.44.0" - constraints = "~> 3.44.0" + constraints = "3.44.0" hashes = [ + "h1:VOVZWybe1x0E4qyawTwt7jXVBRUplTrzVFHim217DqI=", "h1:hxQ8n9SHHfAIXd/FtfAqxokFYWBedzZf7xqQZWJajUs=", "zh:0680315b29a140e9b7e4f5aeed3f2445abdfab31fc9237f34dcad06de4f410df", "zh:13811322a205fb4a0ee617f0ae51ec94176befdf569235d0c7064db911f0acc7", diff --git a/pkg/resource/aws/testdata/acc/aws_elb/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_elb/.terraform.lock.hcl index 2fad61d4..6d3f22df 100644 --- a/pkg/resource/aws/testdata/acc/aws_elb/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_elb/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.44.0" constraints = "3.44.0" hashes = [ + "h1:VOVZWybe1x0E4qyawTwt7jXVBRUplTrzVFHim217DqI=", "h1:hxQ8n9SHHfAIXd/FtfAqxokFYWBedzZf7xqQZWJajUs=", "zh:0680315b29a140e9b7e4f5aeed3f2445abdfab31fc9237f34dcad06de4f410df", "zh:13811322a205fb4a0ee617f0ae51ec94176befdf569235d0c7064db911f0acc7", diff --git a/pkg/resource/aws/testdata/acc/aws_iam_access_key/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_iam_access_key/.terraform.lock.hcl index 442356e0..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_iam_access_key/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_iam_access_key/.terraform.lock.hcl @@ -3,9 +3,10 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.19.0" - constraints = "~> 3.19.0" + constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_iam_group/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_iam_group/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_iam_group/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_iam_group/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_iam_group_policy/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_iam_group_policy/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_iam_group_policy/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_iam_group_policy/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_iam_group_policy_attachment/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_iam_group_policy_attachment/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_iam_group_policy_attachment/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_iam_group_policy_attachment/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_iam_role/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_iam_role/.terraform.lock.hcl index c238c3f6..82187a94 100644 --- a/pkg/resource/aws/testdata/acc/aws_iam_role/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_iam_role/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.45.0" constraints = "3.45.0" hashes = [ + "h1:9l/yDPt/OPG6a0ITu7amfq1LjdnWHTsOgn/KOxM26HA=", "h1:LKU/xfna87/p+hl5yTTW3dvOqWJp5JEM+Dt3nnvSDvA=", "zh:0fdbb3af75ff55807466533f97eb314556ec41a908a543d7cafb06546930f7c6", "zh:20656895744fa0f4607096b9681c77b2385f450b1577f9151d3070818378a724", diff --git a/pkg/resource/aws/testdata/acc/aws_iam_role_with_managed_policies/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_iam_role_with_managed_policies/.terraform.lock.hcl index c238c3f6..82187a94 100644 --- a/pkg/resource/aws/testdata/acc/aws_iam_role_with_managed_policies/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_iam_role_with_managed_policies/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.45.0" constraints = "3.45.0" hashes = [ + "h1:9l/yDPt/OPG6a0ITu7amfq1LjdnWHTsOgn/KOxM26HA=", "h1:LKU/xfna87/p+hl5yTTW3dvOqWJp5JEM+Dt3nnvSDvA=", "zh:0fdbb3af75ff55807466533f97eb314556ec41a908a543d7cafb06546930f7c6", "zh:20656895744fa0f4607096b9681c77b2385f450b1577f9151d3070818378a724", diff --git a/pkg/resource/aws/testdata/acc/aws_instance/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_instance/.terraform.lock.hcl old mode 100755 new mode 100644 index 88e37a28..d645626d --- a/pkg/resource/aws/testdata/acc/aws_instance/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_instance/.terraform.lock.hcl @@ -3,9 +3,10 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.19.0" - constraints = "~> 3.19.0" + constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", @@ -23,6 +24,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.0.0" hashes = [ "h1:grDzxfnOdFXi90FRIIwP/ZrCzirJ/SfsGBe6cE0Shg4=", + "h1:yhHJpb4IfQQfuio7qjUXuUFTU/s+ensuEpm23A+VWz0=", "zh:0fcb00ff8b87dcac1b0ee10831e47e0203a6c46aafd76cb140ba2bab81f02c6b", "zh:123c984c0e04bad910c421028d18aa2ca4af25a153264aef747521f4e7c36a17", "zh:287443bc6fd7fa9a4341dec235589293cbcc6e467a042ae225fd5d161e4e68dc", diff --git a/pkg/resource/aws/testdata/acc/aws_instance_default/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_instance_default/.terraform.lock.hcl index c238c3f6..82187a94 100644 --- a/pkg/resource/aws/testdata/acc/aws_instance_default/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_instance_default/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.45.0" constraints = "3.45.0" hashes = [ + "h1:9l/yDPt/OPG6a0ITu7amfq1LjdnWHTsOgn/KOxM26HA=", "h1:LKU/xfna87/p+hl5yTTW3dvOqWJp5JEM+Dt3nnvSDvA=", "zh:0fdbb3af75ff55807466533f97eb314556ec41a908a543d7cafb06546930f7c6", "zh:20656895744fa0f4607096b9681c77b2385f450b1577f9151d3070818378a724", diff --git a/pkg/resource/aws/testdata/acc/aws_lambda_event_source_mapping/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_lambda_event_source_mapping/.terraform.lock.hcl index b6fbf3c6..8a74c4b9 100644 --- a/pkg/resource/aws/testdata/acc/aws_lambda_event_source_mapping/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_lambda_event_source_mapping/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.32.0" hashes = [ "h1:l8jJYQ4bPEbNwZUoHYmeR1woajPzJSX5hPCLWuRVFwc=", + "h1:tY5R3hRBB1v8v3MT+ofoCJ/Vz/b/4PXo3xtBKtw7T4A=", "zh:04e4f700c21b1f58e7603638160bd5ad3b85519c35dc75bada3e52b164d06d3e", "zh:09f2338404d4b2d4dcb29781ac59a6955d935745e896d4ee661d83cac8d7c677", "zh:16bdf96d8139268766921d5b891b865f67936190dc302283ba50b94e42510ec5", diff --git a/pkg/resource/aws/testdata/acc/aws_launch_template/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_launch_template/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_launch_template/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_launch_template/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_lb/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_lb/.terraform.lock.hcl index 34706398..cb60e412 100644 --- a/pkg/resource/aws/testdata/acc/aws_lb/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_lb/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "4.8.0" hashes = [ "h1:T9Typ5V+dDwecG9USCLbW4oayxN3cxEGsG+OJzzjRgY=", + "h1:W2cPGKmqkPbTc91lu42QeC3RFBqB5TnRnS3IxNME2FM=", "zh:16cbdbc03ad13358d12433e645e2ab5a615e3a3662a74e3c317267c9377713d8", "zh:1d813c5e6c21fe370652495e29f783db4e65037f913ff0d53d28515c36fbb70a", "zh:31ad8282e31d0fac62e96fc2321a68ad4b92ab90f560be5f875d1b01a493e491", diff --git a/pkg/resource/aws/testdata/acc/aws_lb_listener/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_lb_listener/.terraform.lock.hcl index ae023925..c236fad4 100644 --- a/pkg/resource/aws/testdata/acc/aws_lb_listener/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_lb_listener/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "4.9.0" hashes = [ "h1:GtmIOZMkKmr9tMLWouHWiGXmKEL/diOTNar5XfOVLjs=", + "h1:OWIIlbMZl/iQ8qR1U7Co3sGjNHL1HJtgNRnnV1kXNuI=", "zh:084b83aef3335ad4f5e4b8323c6fe43c1ff55e17a7647c6a5cad6af519f72b42", "zh:132e47ce69f14de4523b84b213cedf7173398acda14245b1ffe7747aac50f050", "zh:2068baef7dfce3613f3b4f27314175e971f8db68d9cde9ec30b5659f80c68c6c", diff --git a/pkg/resource/aws/testdata/acc/aws_nat_gateway/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_nat_gateway/.terraform.lock.hcl old mode 100755 new mode 100644 index 0b377f59..8076e21f --- a/pkg/resource/aws/testdata/acc/aws_nat_gateway/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_nat_gateway/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_network_acl/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_network_acl/.terraform.lock.hcl index 642a9881..51efd1d6 100644 --- a/pkg/resource/aws/testdata/acc/aws_network_acl/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_network_acl/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.47.0" hashes = [ "h1:gXncRh1KtgLNMeb3/bYq5CvGfy8YTR+n6ds1noc5ggc=", + "h1:oiX6JcoXY6lYIdcYWmEpr7mnS4mkyDV9intCNrcjiBs=", "zh:07bb6bda5b9fdb782dd568a2e85cfe0ab108770e2218f3411e57ed845c58af40", "zh:0926b161a109e75bdc8691e8a32f568b4cd77a55510cf27573261fb5ba382287", "zh:0a91adf25a78ad31d547da513db24f493d27592d3675ed291a7698351c30992d", diff --git a/pkg/resource/aws/testdata/acc/aws_rds_cluster_instance/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_rds_cluster_instance/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_rds_cluster_instance/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_rds_cluster_instance/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_route/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_route/.terraform.lock.hcl index 2fad61d4..6d3f22df 100644 --- a/pkg/resource/aws/testdata/acc/aws_route/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_route/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.44.0" constraints = "3.44.0" hashes = [ + "h1:VOVZWybe1x0E4qyawTwt7jXVBRUplTrzVFHim217DqI=", "h1:hxQ8n9SHHfAIXd/FtfAqxokFYWBedzZf7xqQZWJajUs=", "zh:0680315b29a140e9b7e4f5aeed3f2445abdfab31fc9237f34dcad06de4f410df", "zh:13811322a205fb4a0ee617f0ae51ec94176befdf569235d0c7064db911f0acc7", diff --git a/pkg/resource/aws/testdata/acc/aws_route53_health_check/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_route53_health_check/.terraform.lock.hcl index 9e715343..a81a0e2e 100644 --- a/pkg/resource/aws/testdata/acc/aws_route53_health_check/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_route53_health_check/.terraform.lock.hcl @@ -4,6 +4,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.28.0" hashes = [ + "h1:0cCqlVoOAj4YOi61kVpqoxu1bdAmB67z6uZf+lsHJOw=", "h1:zejsAukFmgZCOdQCk44L3cumXFs8YDSltRIjZN+izsU=", "zh:1fee7fce319be5bea7df2e95f28a78a04e15c18bad5eb56dcc0ecc324c97f4b8", "zh:2383ff31ef7411f7d4bef1ee288f0f79bec41cf220ac94c2b31f6a702b26f984", diff --git a/pkg/resource/aws/testdata/acc/aws_route53_record/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_route53_record/.terraform.lock.hcl old mode 100755 new mode 100644 index 442356e0..8076e21f --- a/pkg/resource/aws/testdata/acc/aws_route53_record/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_route53_record/.terraform.lock.hcl @@ -3,9 +3,10 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.19.0" - constraints = "~> 3.19.0" + constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_route_table/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_route_table/.terraform.lock.hcl index ab9ad3ad..ddde27f9 100644 --- a/pkg/resource/aws/testdata/acc/aws_route_table/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_route_table/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.62.0" constraints = "3.62.0" hashes = [ + "h1:3aKS+Lra6yHSs6zMqgVZXBZhYG7nkHS6DED8sG+rAlo=", "h1:UjsV2CRiVU3ye7w9AabX6t/bmDuAF5mt+fr63/pfDHQ=", "zh:08a94019e17304f5927d7c85b8f5dade6b9ffebeb7b06ec0643aaa1130c4c85d", "zh:0e3709f6c1fed8c5119a5653bec7e3069258ddf91f62d851f8deeede10487fb8", diff --git a/pkg/resource/aws/testdata/acc/aws_route_table_association/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_route_table_association/.terraform.lock.hcl old mode 100755 new mode 100644 index 0b377f59..8076e21f --- a/pkg/resource/aws/testdata/acc/aws_route_table_association/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_route_table_association/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_route_with_prefix_list_id/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_route_with_prefix_list_id/.terraform.lock.hcl index 1a5ac31c..99756de0 100644 --- a/pkg/resource/aws/testdata/acc/aws_route_with_prefix_list_id/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_route_with_prefix_list_id/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.75.1" hashes = [ "h1:++H0a4igODgreQL3SJuRz71JZkC69rl41R8xLYM894o=", + "h1:zgO9MSF32Rz6lOBumY+FyPZESYwlL5SUXOViTV5cs28=", "zh:11c2ee541ca1da923356c9225575ba294523d7b6af82d6171c912470ef0f90cd", "zh:19fe975993664252b4a2ff1079546f2b186b01d1a025a94a4f15c37e023806c5", "zh:442e7fc145b2debebe9279b283d07f5f736dc1776c2e5b1702728a6eb03789d0", diff --git a/pkg/resource/aws/testdata/acc/aws_s3_bucket/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_s3_bucket/.terraform.lock.hcl old mode 100755 new mode 100644 index 8abce94c..c41723df --- a/pkg/resource/aws/testdata/acc/aws_s3_bucket/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_s3_bucket/.terraform.lock.hcl @@ -3,9 +3,10 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.19.0" - constraints = "~> 3.19.0" + constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", @@ -22,6 +23,7 @@ provider "registry.terraform.io/hashicorp/aws" { provider "registry.terraform.io/hashicorp/random" { version = "3.0.1" hashes = [ + "h1:0QaSbRBgBi8vI/8IRwec1INdOqBxXbgsSFElx1O4k4g=", "h1:SzM8nt2wzLMI28A3CWAtW25g3ZCm1O4xD0h3Ps/rU1U=", "zh:0d4f683868324af056a9eb2b06306feef7c202c88dbbe6a4ad7517146a22fb50", "zh:4824b3c7914b77d41dfe90f6f333c7ac9860afb83e2a344d91fbe46e5dfbec26", diff --git a/pkg/resource/aws/testdata/acc/aws_s3_bucket_public_access_block/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_s3_bucket_public_access_block/.terraform.lock.hcl index 7d2353ee..b7a6b26e 100644 --- a/pkg/resource/aws/testdata/acc/aws_s3_bucket_public_access_block/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_s3_bucket_public_access_block/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", @@ -23,6 +24,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.2" hashes = [ "h1:5A5VsY5wNmOZlupUcLnIoziMPn8htSZBXbP3lI7lBEM=", + "h1:9A6Ghjgad0KjJRxa6nPo8i8uFvwj3Vv0wnEgy49u+24=", "zh:0daceba867b330d3f8e2c5dc895c4291845a78f31955ce1b91ab2c4d1cd1c10b", "zh:104050099efd30a630741f788f9576b19998e7a09347decbec3da0b21d64ba2d", "zh:173f4ef3fdf0c7e2564a3db0fac560e9f5afdf6afd0b75d6646af6576b122b16", diff --git a/pkg/resource/aws/testdata/acc/aws_sns_topic/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_sns_topic/.terraform.lock.hcl old mode 100755 new mode 100644 index 3c981c6f..0cca25ba --- a/pkg/resource/aws/testdata/acc/aws_sns_topic/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_sns_topic/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.26.0" hashes = [ "h1:0i78FItlPeiomd+4ThZrtm56P5K33k7/6dnEe4ZePI0=", + "h1:b1qNzEzDHZpnHSOW4fRo1PFC0U2Ft25PKKs9NSDGe3U=", "zh:26043eed36d070ca032cf04bc980c654a25821a8abc0c85e1e570e3935bbfcbb", "zh:2fe68f3f78d23830a04d7fac3eda550eef1f627dfc130486f70a65dc5c254300", "zh:3d66484c608c64678e639db25d63872783ce60363a1246e30317f21c9c23b84b", diff --git a/pkg/resource/aws/testdata/acc/aws_sns_topic_policy/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_sns_topic_policy/.terraform.lock.hcl old mode 100755 new mode 100644 index 3c981c6f..0cca25ba --- a/pkg/resource/aws/testdata/acc/aws_sns_topic_policy/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_sns_topic_policy/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.26.0" hashes = [ "h1:0i78FItlPeiomd+4ThZrtm56P5K33k7/6dnEe4ZePI0=", + "h1:b1qNzEzDHZpnHSOW4fRo1PFC0U2Ft25PKKs9NSDGe3U=", "zh:26043eed36d070ca032cf04bc980c654a25821a8abc0c85e1e570e3935bbfcbb", "zh:2fe68f3f78d23830a04d7fac3eda550eef1f627dfc130486f70a65dc5c254300", "zh:3d66484c608c64678e639db25d63872783ce60363a1246e30317f21c9c23b84b", diff --git a/pkg/resource/aws/testdata/acc/aws_sns_topic_subscription/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_sns_topic_subscription/.terraform.lock.hcl old mode 100755 new mode 100644 index a79e6a89..1ed4754d --- a/pkg/resource/aws/testdata/acc/aws_sns_topic_subscription/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_sns_topic_subscription/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/aws" { version = "3.27.0" hashes = [ "h1:ccxtk7jAtmBPvAEXswOEYJcyp5jTD9QlQeg8GEzYmxQ=", + "h1:lJaA23rrNSgLEumcW7Y0KKvno5isVVNNIEV5pT92O8E=", "zh:2986eb5a1ffbb0336c6390aad533b62efc832aa8aa5460d523e1f2daa4f42f79", "zh:825317cdb80860833125a856c0befc877cba22d41c631c5a7ca22400693d4356", "zh:a47aad668cc74058f508c56c5407cd715dbb9b6389aa68d37543e897895db43f", diff --git a/pkg/resource/aws/testdata/acc/aws_subnet/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_subnet/.terraform.lock.hcl old mode 100755 new mode 100644 index 0b377f59..8076e21f --- a/pkg/resource/aws/testdata/acc/aws_subnet/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_subnet/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/aws/testdata/acc/aws_vpc/.terraform.lock.hcl b/pkg/resource/aws/testdata/acc/aws_vpc/.terraform.lock.hcl index 0b377f59..8076e21f 100644 --- a/pkg/resource/aws/testdata/acc/aws_vpc/.terraform.lock.hcl +++ b/pkg/resource/aws/testdata/acc/aws_vpc/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/aws" { constraints = "3.19.0" hashes = [ "h1:+7Vi7p13+cnrxjXbfJiTimGSFR97xCaQwkkvWcreLns=", + "h1:xur9tF49NgsovNnmwmBR8RdpN8Fcg1TD4CKQPJD6n1A=", "zh:185a5259153eb9ee4699d4be43b3d509386b473683392034319beee97d470c3b", "zh:2d9a0a01f93e8d16539d835c02b8b6e1927b7685f4076e96cb07f7dd6944bc6c", "zh:703f6da36b1b5f3497baa38fccaa7765fb8a2b6440344e4c97172516b49437dd", diff --git a/pkg/resource/azurerm/azurerm_container_registry.go b/pkg/resource/azurerm/azurerm_container_registry.go deleted file mode 100644 index 91593001..00000000 --- a/pkg/resource/azurerm/azurerm_container_registry.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzureContainerRegistryResourceType = "azurerm_container_registry" - -func initAzureContainerRegistryMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureContainerRegistryResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_firewall.go b/pkg/resource/azurerm/azurerm_firewall.go deleted file mode 100644 index b28c2394..00000000 --- a/pkg/resource/azurerm/azurerm_firewall.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzureFirewallResourceType = "azurerm_firewall" - -func initAzureFirewallMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureFirewallResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_image.go b/pkg/resource/azurerm/azurerm_image.go deleted file mode 100644 index 79f320d2..00000000 --- a/pkg/resource/azurerm/azurerm_image.go +++ /dev/null @@ -1,19 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AzureImageResourceType = "azurerm_image" - -func initAzureImageMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureImageResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_lb.go b/pkg/resource/azurerm/azurerm_lb.go deleted file mode 100644 index 59b22f26..00000000 --- a/pkg/resource/azurerm/azurerm_lb.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzureLoadBalancerResourceType = "azurerm_lb" - -func initAzureLoadBalancerMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureLoadBalancerResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_lb_rule.go b/pkg/resource/azurerm/azurerm_lb_rule.go index 53e8a8f5..089eee7c 100644 --- a/pkg/resource/azurerm/azurerm_lb_rule.go +++ b/pkg/resource/azurerm/azurerm_lb_rule.go @@ -1,24 +1,12 @@ package azurerm -import "github.com/snyk/driftctl/pkg/resource" - -const AzureLoadBalancerRuleResourceType = "azurerm_lb_rule" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) func initAzureLoadBalancerRuleMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzureLoadBalancerRuleResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzureLoadBalancerRuleResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetResolveReadAttributesFunc(AzureLoadBalancerRuleResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "loadbalancer_id": *res.Attributes().GetString("loadbalancer_id"), - } - }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureLoadBalancerRuleResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - if name := res.Attributes().GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzureLoadBalancerRuleResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_network_security_group.go b/pkg/resource/azurerm/azurerm_network_security_group.go index 594cc5ee..b9ff04df 100644 --- a/pkg/resource/azurerm/azurerm_network_security_group.go +++ b/pkg/resource/azurerm/azurerm_network_security_group.go @@ -1,20 +1,12 @@ package azurerm -import "github.com/snyk/driftctl/pkg/resource" - -const AzureNetworkSecurityGroupResourceType = "azurerm_network_security_group" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) func initAzureNetworkSecurityGroupMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzureNetworkSecurityGroupResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzureNetworkSecurityGroupResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureNetworkSecurityGroupResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzureNetworkSecurityGroupResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_postgresql_database.go b/pkg/resource/azurerm/azurerm_postgresql_database.go deleted file mode 100644 index caa41b16..00000000 --- a/pkg/resource/azurerm/azurerm_postgresql_database.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzurePostgresqlDatabaseResourceType = "azurerm_postgresql_database" - -func initAzurePostgresqlDatabaseMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePostgresqlDatabaseResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_postgresql_server.go b/pkg/resource/azurerm/azurerm_postgresql_server.go deleted file mode 100644 index 380f152b..00000000 --- a/pkg/resource/azurerm/azurerm_postgresql_server.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzurePostgresqlServerResourceType = "azurerm_postgresql_server" - -func initAzurePostgresqlServerMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePostgresqlServerResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_private_dns_a_record.go b/pkg/resource/azurerm/azurerm_private_dns_a_record.go index 42367754..46707086 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_a_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_a_record.go @@ -1,25 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSARecordResourceType = "azurerm_private_dns_a_record" - func initAzurePrivateDNSARecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSARecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSARecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSARecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSARecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_aaaa_record.go b/pkg/resource/azurerm/azurerm_private_dns_aaaa_record.go index 7e78924b..2b0d1003 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_aaaa_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_aaaa_record.go @@ -1,25 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSAAAARecordResourceType = "azurerm_private_dns_aaaa_record" - func initAzurePrivateDNSAAAARecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSAAAARecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSAAAARecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSAAAARecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSAAAARecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_cname_record.go b/pkg/resource/azurerm/azurerm_private_dns_cname_record.go index a53b898d..e15b72d1 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_cname_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_cname_record.go @@ -1,24 +1,12 @@ package azurerm -import "github.com/snyk/driftctl/pkg/resource" - -const AzurePrivateDNSCNameRecordResourceType = "azurerm_private_dns_cname_record" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" +) func initAzurePrivateDNSCNameRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetFlags(AzurePrivateDNSCNameRecordResourceType, resource.FlagDeepMode) - - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSCNameRecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSCNameRecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSCNameRecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_mx_record.go b/pkg/resource/azurerm/azurerm_private_dns_mx_record.go index 6f3d22b7..60e92cf1 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_mx_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_mx_record.go @@ -1,25 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSMXRecordResourceType = "azurerm_private_dns_mx_record" - func initAzurePrivateDNSMXRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSMXRecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSMXRecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSMXRecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSMXRecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_ptr_record.go b/pkg/resource/azurerm/azurerm_private_dns_ptr_record.go index af0c58c9..c10884b8 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_ptr_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_ptr_record.go @@ -1,25 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSPTRRecordResourceType = "azurerm_private_dns_ptr_record" - func initAzurePrivateDNSPTRRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSPTRRecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSPTRRecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSPTRRecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSPTRRecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_srv_record.go b/pkg/resource/azurerm/azurerm_private_dns_srv_record.go index 4d839b95..fbed64f9 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_srv_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_srv_record.go @@ -1,25 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSSRVRecordResourceType = "azurerm_private_dns_srv_record" - func initAzurePrivateDNSSRVRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSSRVRecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSSRVRecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSSRVRecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSSRVRecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_txt_record.go b/pkg/resource/azurerm/azurerm_private_dns_txt_record.go index 795e97e2..01bbe9b8 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_txt_record.go +++ b/pkg/resource/azurerm/azurerm_private_dns_txt_record.go @@ -1,25 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSTXTRecordResourceType = "azurerm_private_dns_txt_record" - func initAzurePrivateDNSTXTRecordMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSTXTRecordResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSTXTRecordResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePrivateDNSTXTRecordResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - if zone := val.GetString("zone_name"); zone != nil && *zone != "" { - attrs["Zone"] = *zone - } - return attrs - }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSTXTRecordResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_private_dns_zone.go b/pkg/resource/azurerm/azurerm_private_dns_zone.go index 45514f3a..28776192 100644 --- a/pkg/resource/azurerm/azurerm_private_dns_zone.go +++ b/pkg/resource/azurerm/azurerm_private_dns_zone.go @@ -1,15 +1,13 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzurePrivateDNSZoneResourceType = "azurerm_private_dns_zone" - func initAzurePrivateDNSZoneMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzurePrivateDNSZoneResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzurePrivateDNSZoneResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"number_of_record_sets"}) res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(AzurePrivateDNSZoneResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_public_ip.go b/pkg/resource/azurerm/azurerm_public_ip.go deleted file mode 100644 index ce2b0374..00000000 --- a/pkg/resource/azurerm/azurerm_public_ip.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzurePublicIPResourceType = "azurerm_public_ip" - -func initAzurePublicIPMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzurePublicIPResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_resource_group.go b/pkg/resource/azurerm/azurerm_resource_group.go deleted file mode 100644 index 88f5daf8..00000000 --- a/pkg/resource/azurerm/azurerm_resource_group.go +++ /dev/null @@ -1,16 +0,0 @@ -package azurerm - -import "github.com/snyk/driftctl/pkg/resource" - -const AzureResourceGroupResourceType = "azurerm_resource_group" - -func initAzureResourceGroupMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureResourceGroupResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_route.go b/pkg/resource/azurerm/azurerm_route.go deleted file mode 100644 index 082184f0..00000000 --- a/pkg/resource/azurerm/azurerm_route.go +++ /dev/null @@ -1,23 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AzureRouteResourceType = "azurerm_route" - -func initAzureRouteMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureRouteResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - - if v := res.Attributes().GetString("route_table_name"); v != nil && *v != "" { - attrs["Table"] = *v - } - - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_route_table.go b/pkg/resource/azurerm/azurerm_route_table.go deleted file mode 100644 index fb2065d8..00000000 --- a/pkg/resource/azurerm/azurerm_route_table.go +++ /dev/null @@ -1,18 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AzureRouteTableResourceType = "azurerm_route_table" - -func initAzureRouteTableMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureRouteTableResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/azurerm_ssh_public_key.go b/pkg/resource/azurerm/azurerm_ssh_public_key.go index 17e87abe..eaf521ed 100644 --- a/pkg/resource/azurerm/azurerm_ssh_public_key.go +++ b/pkg/resource/azurerm/azurerm_ssh_public_key.go @@ -1,23 +1,12 @@ package azurerm import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/azurerm" ) -const AzureSSHPublicKeyResourceType = "azurerm_ssh_public_key" - func initAzureSSHPublicKeyMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(AzureSSHPublicKeyResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(azurerm.AzureSSHPublicKeyResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureSSHPublicKeyResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - - return attrs - }) - resourceSchemaRepository.SetFlags(AzureSSHPublicKeyResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/azurerm/azurerm_virtual_network.go b/pkg/resource/azurerm/azurerm_virtual_network.go deleted file mode 100644 index 47049252..00000000 --- a/pkg/resource/azurerm/azurerm_virtual_network.go +++ /dev/null @@ -1,18 +0,0 @@ -package azurerm - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -const AzureVirtualNetworkResourceType = "azurerm_virtual_network" - -func initAzureVirtualNetworkMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(AzureVirtualNetworkResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - return attrs - }) -} diff --git a/pkg/resource/azurerm/metadata.go b/pkg/resource/azurerm/metadata.go index 2a4c4767..0aa92160 100644 --- a/pkg/resource/azurerm/metadata.go +++ b/pkg/resource/azurerm/metadata.go @@ -1,19 +1,9 @@ package azurerm -import "github.com/snyk/driftctl/pkg/resource" +import "github.com/snyk/driftctl/enumeration/resource" func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - initAzureVirtualNetworkMetaData(resourceSchemaRepository) - initAzureRouteTableMetaData(resourceSchemaRepository) - initAzureRouteMetaData(resourceSchemaRepository) - initAzureResourceGroupMetadata(resourceSchemaRepository) - initAzureContainerRegistryMetadata(resourceSchemaRepository) - initAzureFirewallMetadata(resourceSchemaRepository) - initAzurePostgresqlServerMetadata(resourceSchemaRepository) - initAzurePublicIPMetadata(resourceSchemaRepository) - initAzurePostgresqlDatabaseMetadata(resourceSchemaRepository) initAzureNetworkSecurityGroupMetadata(resourceSchemaRepository) - initAzureLoadBalancerMetadata(resourceSchemaRepository) initAzurePrivateDNSZoneMetaData(resourceSchemaRepository) initAzurePrivateDNSARecordMetaData(resourceSchemaRepository) initAzurePrivateDNSAAAARecordMetaData(resourceSchemaRepository) @@ -21,7 +11,6 @@ func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInt initAzurePrivateDNSSRVRecordMetaData(resourceSchemaRepository) initAzurePrivateDNSMXRecordMetaData(resourceSchemaRepository) initAzurePrivateDNSTXTRecordMetaData(resourceSchemaRepository) - initAzureImageMetaData(resourceSchemaRepository) initAzureSSHPublicKeyMetaData(resourceSchemaRepository) initAzurePrivateDNSCNameRecordMetaData(resourceSchemaRepository) initAzureLoadBalancerRuleMetadata(resourceSchemaRepository) diff --git a/pkg/resource/azurerm/metadata_test.go b/pkg/resource/azurerm/metadata_test.go deleted file mode 100644 index 8e1dce50..00000000 --- a/pkg/resource/azurerm/metadata_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package azurerm - -import ( - "testing" - - "github.com/snyk/driftctl/pkg/resource" - tf "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" -) - -func TestAzureMetadata_Flags(t *testing.T) { - testcases := map[string][]resource.Flags{ - AzureContainerRegistryResourceType: {}, - AzureFirewallResourceType: {}, - AzurePostgresqlServerResourceType: {}, - AzurePostgresqlDatabaseResourceType: {}, - AzurePublicIPResourceType: {}, - AzureResourceGroupResourceType: {}, - AzureRouteResourceType: {}, - AzureRouteTableResourceType: {}, - AzureStorageAccountResourceType: {}, - AzureStorageContainerResourceType: {}, - AzureSubnetResourceType: {}, - AzureVirtualNetworkResourceType: {}, - AzureNetworkSecurityGroupResourceType: {resource.FlagDeepMode}, - AzureLoadBalancerResourceType: {}, - AzurePrivateDNSZoneResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSARecordResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSAAAARecordResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSCNameRecordResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSPTRRecordResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSMXRecordResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSSRVRecordResourceType: {resource.FlagDeepMode}, - AzurePrivateDNSTXTRecordResourceType: {resource.FlagDeepMode}, - AzureImageResourceType: {}, - AzureSSHPublicKeyResourceType: {resource.FlagDeepMode}, - AzureLoadBalancerRuleResourceType: {resource.FlagDeepMode}, - } - - schemaRepository := testresource.InitFakeSchemaRepository(tf.AZURE, "2.71.0") - InitResourcesMetadata(schemaRepository) - - for ty, flags := range testcases { - t.Run(ty, func(tt *testing.T) { - sch, exist := schemaRepository.GetSchema(ty) - assert.True(tt, exist) - - if len(flags) == 0 { - assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") - return - } - - for _, flag := range flags { - assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) - } - }) - } -} diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_image/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_image/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_image/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_image/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_lb_rule/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_lb_rule/.terraform.lock.hcl index 392ede8b..952bb771 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_lb_rule/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_lb_rule/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_a_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_a_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_a_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_a_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_aaaa_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_aaaa_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_aaaa_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_aaaa_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_cname_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_cname_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_cname_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_cname_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_mx_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_mx_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_mx_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_mx_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_ptr_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_ptr_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_ptr_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_ptr_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_srv_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_srv_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_srv_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_srv_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_txt_record/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_txt_record/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_txt_record/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_txt_record/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_zone/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_zone/.terraform.lock.hcl index 67ef8eaf..4f1629f0 100644 --- a/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_zone/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_private_dns_zone/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/azurerm/testdata/acc/azurerm_resource_group/.terraform.lock.hcl b/pkg/resource/azurerm/testdata/acc/azurerm_resource_group/.terraform.lock.hcl old mode 100755 new mode 100644 index 67ef8eaf..4f1629f0 --- a/pkg/resource/azurerm/testdata/acc/azurerm_resource_group/.terraform.lock.hcl +++ b/pkg/resource/azurerm/testdata/acc/azurerm_resource_group/.terraform.lock.hcl @@ -6,6 +6,7 @@ provider "registry.terraform.io/hashicorp/azurerm" { constraints = "~> 2.71.0" hashes = [ "h1:RiFIxNI4Yr9CqleqEdgg1ydLAZ5JiYiz6l5iTD3WcuU=", + "h1:ULax/q7p3Tl0l8DnXV9GNmdDRR1MHpimyLq8OP6E6I0=", "zh:2b9d8a703a0222f72cbceb8d2bdb580066afdcd7f28b6ad65d5ed935319b5433", "zh:332988f4c1747bcc8ebd32734bf8de2bea4c13a6fbd08d7eb97d0c43d335b15e", "zh:3a902470276ba48e23ad4dd6baff16a9ce3b60b29c0b07064dbe96ce4640a31c", diff --git a/pkg/resource/deserializer.go b/pkg/resource/deserializer.go index aa71080a..1b16a1a0 100644 --- a/pkg/resource/deserializer.go +++ b/pkg/resource/deserializer.go @@ -3,20 +3,21 @@ package resource import ( "encoding/json" + "github.com/snyk/driftctl/enumeration/resource" "github.com/zclconf/go-cty/cty" ctyjson "github.com/zclconf/go-cty/cty/json" ) type Deserializer struct { - factory ResourceFactory + factory resource.ResourceFactory } -func NewDeserializer(factory ResourceFactory) *Deserializer { +func NewDeserializer(factory resource.ResourceFactory) *Deserializer { return &Deserializer{factory} } -func (s *Deserializer) Deserialize(ty string, rawList []cty.Value) ([]*Resource, error) { - resources := make([]*Resource, 0) +func (s *Deserializer) Deserialize(ty string, rawList []cty.Value) ([]*resource.Resource, error) { + resources := make([]*resource.Resource, 0) for _, rawResource := range rawList { rawResource := rawResource res, err := s.DeserializeOne(ty, rawResource) @@ -28,7 +29,7 @@ func (s *Deserializer) Deserialize(ty string, rawList []cty.Value) ([]*Resource, return resources, nil } -func (s *Deserializer) DeserializeOne(ty string, value cty.Value) (*Resource, error) { +func (s *Deserializer) DeserializeOne(ty string, value cty.Value) (*resource.Resource, error) { if value.IsNull() { return nil, nil } @@ -37,7 +38,7 @@ func (s *Deserializer) DeserializeOne(ty string, value cty.Value) (*Resource, er // For example, this ensures we can deserialize sensitive values too. unmarkedVal, _ := value.UnmarkDeep() - var attrs Attributes + var attrs resource.Attributes bytes, _ := ctyjson.Marshal(unmarkedVal, unmarkedVal.Type()) err := json.Unmarshal(bytes, &attrs) if err != nil { diff --git a/pkg/resource/github/github_branch_protection.go b/pkg/resource/github/github_branch_protection.go index 154cb0e5..39760431 100644 --- a/pkg/resource/github/github_branch_protection.go +++ b/pkg/resource/github/github_branch_protection.go @@ -2,42 +2,13 @@ package github import ( - "encoding/base64" - - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" ) -const GithubBranchProtectionResourceType = "github_branch_protection" - func initGithubBranchProtectionMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GithubBranchProtectionResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(github.GithubBranchProtectionResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"repository_id"}) // Terraform provider is always returning nil }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(GithubBranchProtectionResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - pattern := val.GetString("pattern") - repoID := val.GetString("repository_id") - if pattern != nil && *pattern != "" { - id := "" - if repoID != nil && *repoID != "" { - decodedID, err := base64.StdEncoding.DecodeString(*repoID) - if err == nil { - id = string(decodedID) - } - } - if id == "" { - attrs["Branch"] = *pattern - attrs["Id"] = res.ResourceId() - return attrs - } - attrs["Branch"] = *pattern - attrs["RepoId"] = id - return attrs - } - attrs["Id"] = res.ResourceId() - return attrs - }) - resourceSchemaRepository.SetFlags(GithubBranchProtectionResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/github/github_membership.go b/pkg/resource/github/github_membership.go index 16cf59d0..09b4e46c 100644 --- a/pkg/resource/github/github_membership.go +++ b/pkg/resource/github/github_membership.go @@ -1,14 +1,14 @@ // GENERATED, DO NOT EDIT THIS FILE package github -import "github.com/snyk/driftctl/pkg/resource" - -const GithubMembershipResourceType = "github_membership" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) func initGithubMembershipMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GithubMembershipResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(github.GithubMembershipResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"etag"}) }) - resourceSchemaRepository.SetFlags(GithubMembershipResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/github/github_repository.go b/pkg/resource/github/github_repository.go index bd4df0b0..94654bb8 100644 --- a/pkg/resource/github/github_repository.go +++ b/pkg/resource/github/github_repository.go @@ -1,14 +1,14 @@ package github -import "github.com/snyk/driftctl/pkg/resource" - -const GithubRepositoryResourceType = "github_repository" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) func initGithubRepositoryMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GithubRepositoryResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(github.GithubRepositoryResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"auto_init"}) val.SafeDelete([]string{"etag"}) }) - resourceSchemaRepository.SetFlags(GithubRepositoryResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/github/github_team.go b/pkg/resource/github/github_team.go index 5288480f..302fcf0c 100644 --- a/pkg/resource/github/github_team.go +++ b/pkg/resource/github/github_team.go @@ -1,26 +1,17 @@ // GENERATED, DO NOT EDIT THIS FILE package github -import "github.com/snyk/driftctl/pkg/resource" - -const GithubTeamResourceType = "github_team" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) func initGithubTeamMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GithubTeamResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(github.GithubTeamResourceType, func(res *resource.Resource) { val := res.Attrs if defaultMaintainer, exist := val.Get("create_default_maintainer"); !exist || defaultMaintainer == nil { (*val)["create_default_maintainer"] = false } val.SafeDelete([]string{"etag"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(GithubTeamResourceType, func(res *resource.Resource) map[string]string { - val := res.Attrs - attrs := make(map[string]string) - attrs["Id"] = res.ResourceId() - if name := val.GetString("name"); name != nil && *name != "" { - attrs["Name"] = *name - } - return attrs - }) - resourceSchemaRepository.SetFlags(GithubTeamResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/github/github_team_membership.go b/pkg/resource/github/github_team_membership.go index ccf42ec5..4454ed85 100644 --- a/pkg/resource/github/github_team_membership.go +++ b/pkg/resource/github/github_team_membership.go @@ -1,14 +1,14 @@ // GENERATED, DO NOT EDIT THIS FILE package github -import "github.com/snyk/driftctl/pkg/resource" - -const GithubTeamMembershipResourceType = "github_team_membership" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/github" +) func initGithubTeamMembershipMetaData(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GithubTeamMembershipResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(github.GithubTeamMembershipResourceType, func(res *resource.Resource) { val := res.Attrs val.SafeDelete([]string{"etag"}) }) - resourceSchemaRepository.SetFlags(GithubTeamMembershipResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/github/metadata_test.go b/pkg/resource/github/metadata_test.go deleted file mode 100644 index c5064fe8..00000000 --- a/pkg/resource/github/metadata_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package github - -import ( - "testing" - - "github.com/snyk/driftctl/pkg/resource" - tf "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" -) - -func TestGitHub_Metadata_Flags(t *testing.T) { - testcases := map[string][]resource.Flags{ - GithubBranchProtectionResourceType: {resource.FlagDeepMode}, - GithubMembershipResourceType: {resource.FlagDeepMode}, - GithubTeamMembershipResourceType: {resource.FlagDeepMode}, - GithubRepositoryResourceType: {resource.FlagDeepMode}, - GithubTeamResourceType: {resource.FlagDeepMode}, - } - - schemaRepository := testresource.InitFakeSchemaRepository(tf.GITHUB, "4.4.0") - InitResourcesMetadata(schemaRepository) - - for ty, flags := range testcases { - t.Run(ty, func(tt *testing.T) { - sch, exist := schemaRepository.GetSchema(ty) - assert.True(tt, exist) - - if len(flags) == 0 { - assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") - return - } - - for _, flag := range flags { - assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) - } - }) - } -} diff --git a/pkg/resource/github/metadatas.go b/pkg/resource/github/metadatas.go index 08ec32df..24622785 100644 --- a/pkg/resource/github/metadatas.go +++ b/pkg/resource/github/metadatas.go @@ -1,6 +1,6 @@ package github -import "github.com/snyk/driftctl/pkg/resource" +import "github.com/snyk/driftctl/enumeration/resource" func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { initGithubBranchProtectionMetaData(resourceSchemaRepository) diff --git a/pkg/resource/github/testdata/acc/github_branch_protection/.terraform.lock.hcl b/pkg/resource/github/testdata/acc/github_branch_protection/.terraform.lock.hcl old mode 100755 new mode 100644 index 0cb91a2c..bc6235de --- a/pkg/resource/github/testdata/acc/github_branch_protection/.terraform.lock.hcl +++ b/pkg/resource/github/testdata/acc/github_branch_protection/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/github" { version = "4.4.0" constraints = "4.4.0" hashes = [ + "h1:dgn+oL1cC8kz3ODIuT/PyHqgso00SpItPN089ZuUGt4=", "h1:eKArqtLcYoYUFf4dgNzVemqu2GsoEf7K0ZLEXjSoPBo=", "zh:0ebb07c4971ca7d60fce8614270d056328a121fd4ffbda4b29a06d4a1e90e939", "zh:178b333f2f285c1a59b9335320f584bd01304179c2d6a1919366945b55cfb293", diff --git a/pkg/resource/github/testdata/acc/github_repository/.terraform.lock.hcl b/pkg/resource/github/testdata/acc/github_repository/.terraform.lock.hcl old mode 100755 new mode 100644 index 0cb91a2c..bc6235de --- a/pkg/resource/github/testdata/acc/github_repository/.terraform.lock.hcl +++ b/pkg/resource/github/testdata/acc/github_repository/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/github" { version = "4.4.0" constraints = "4.4.0" hashes = [ + "h1:dgn+oL1cC8kz3ODIuT/PyHqgso00SpItPN089ZuUGt4=", "h1:eKArqtLcYoYUFf4dgNzVemqu2GsoEf7K0ZLEXjSoPBo=", "zh:0ebb07c4971ca7d60fce8614270d056328a121fd4ffbda4b29a06d4a1e90e939", "zh:178b333f2f285c1a59b9335320f584bd01304179c2d6a1919366945b55cfb293", diff --git a/pkg/resource/github/testdata/acc/github_team/.terraform.lock.hcl b/pkg/resource/github/testdata/acc/github_team/.terraform.lock.hcl old mode 100755 new mode 100644 index 0cb91a2c..bc6235de --- a/pkg/resource/github/testdata/acc/github_team/.terraform.lock.hcl +++ b/pkg/resource/github/testdata/acc/github_team/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/github" { version = "4.4.0" constraints = "4.4.0" hashes = [ + "h1:dgn+oL1cC8kz3ODIuT/PyHqgso00SpItPN089ZuUGt4=", "h1:eKArqtLcYoYUFf4dgNzVemqu2GsoEf7K0ZLEXjSoPBo=", "zh:0ebb07c4971ca7d60fce8614270d056328a121fd4ffbda4b29a06d4a1e90e939", "zh:178b333f2f285c1a59b9335320f584bd01304179c2d6a1919366945b55cfb293", diff --git a/pkg/resource/google/google_bigquery_dataset.go b/pkg/resource/google/google_bigquery_dataset.go deleted file mode 100644 index 86eda586..00000000 --- a/pkg/resource/google/google_bigquery_dataset.go +++ /dev/null @@ -1,13 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleBigqueryDatasetResourceType = "google_bigquery_dataset" - -func initGoogleBigqueryDatasetMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleBigqueryDatasetResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attrs.GetString("friendly_name"), - } - }) -} diff --git a/pkg/resource/google/google_bigquery_table.go b/pkg/resource/google/google_bigquery_table.go deleted file mode 100644 index f6aa8d7a..00000000 --- a/pkg/resource/google/google_bigquery_table.go +++ /dev/null @@ -1,13 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleBigqueryTableResourceType = "google_bigquery_table" - -func initGoogleBigqueryTableMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleBigqueryTableResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attrs.GetString("friendly_name"), - } - }) -} diff --git a/pkg/resource/google/google_compute_address.go b/pkg/resource/google/google_compute_address.go deleted file mode 100644 index 61897161..00000000 --- a/pkg/resource/google/google_compute_address.go +++ /dev/null @@ -1,14 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeAddressResourceType = "google_compute_address" - -func initGoogleComputeAddressMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeAddressResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "Name": *res.Attributes().GetString("name"), - "Address": *res.Attributes().GetString("address"), - } - }) -} diff --git a/pkg/resource/google/google_compute_disk.go b/pkg/resource/google/google_compute_disk.go deleted file mode 100644 index 37e02009..00000000 --- a/pkg/resource/google/google_compute_disk.go +++ /dev/null @@ -1,13 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeDiskResourceType = "google_compute_disk" - -func initGoogleComputeDiskMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeDiskResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "Name": *res.Attributes().GetString("name"), - } - }) -} diff --git a/pkg/resource/google/google_compute_firewall.go b/pkg/resource/google/google_compute_firewall.go index 7e961795..953b9074 100644 --- a/pkg/resource/google/google_compute_firewall.go +++ b/pkg/resource/google/google_compute_firewall.go @@ -1,18 +1,12 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeFirewallResourceType = "google_compute_firewall" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleComputeFirewallMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeFirewallResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attrs.GetString("name"), - "project": *res.Attrs.GetString("project"), - } - }) - resourceSchemaRepository.SetNormalizeFunc(GoogleComputeFirewallResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleComputeFirewallResourceType, func(res *resource.Resource) { res.Attrs.SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetFlags(GoogleComputeFirewallResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/google/google_compute_global_address.go b/pkg/resource/google/google_compute_global_address.go deleted file mode 100644 index 00a14140..00000000 --- a/pkg/resource/google/google_compute_global_address.go +++ /dev/null @@ -1,14 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeGlobalAddressResourceType = "google_compute_global_address" - -func initGoogleComputeGlobalAddressMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeGlobalAddressResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "Name": *res.Attributes().GetString("name"), - "Address": *res.Attributes().GetString("address"), - } - }) -} diff --git a/pkg/resource/google/google_compute_health_check.go b/pkg/resource/google/google_compute_health_check.go deleted file mode 100644 index 033e9163..00000000 --- a/pkg/resource/google/google_compute_health_check.go +++ /dev/null @@ -1,13 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeHealthCheckResourceType = "google_compute_health_check" - -func initGoogleComputeHealthCheckMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeHealthCheckResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "Name": *res.Attributes().GetString("name"), - } - }) -} diff --git a/pkg/resource/google/google_compute_image.go b/pkg/resource/google/google_compute_image.go deleted file mode 100644 index 7b30ae04..00000000 --- a/pkg/resource/google/google_compute_image.go +++ /dev/null @@ -1,13 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeImageResourceType = "google_compute_image" - -func initGoogleComputeImageMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeImageResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "Name": *res.Attributes().GetString("name"), - } - }) -} diff --git a/pkg/resource/google/google_compute_instance_group.go b/pkg/resource/google/google_compute_instance_group.go index 16b0601f..3fe2a939 100644 --- a/pkg/resource/google/google_compute_instance_group.go +++ b/pkg/resource/google/google_compute_instance_group.go @@ -1,26 +1,12 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeInstanceGroupResourceType = "google_compute_instance_group" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleComputeInstanceGroupMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GoogleComputeInstanceGroupResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleComputeInstanceGroupResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) }) - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeInstanceGroupResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attributes().GetString("name"), - "project": *res.Attributes().GetString("project"), - "zone": *res.Attributes().GetString("location"), - } - }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeInstanceGroupResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - return attrs - }) - resourceSchemaRepository.SetFlags(GoogleComputeInstanceGroupResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/google/google_compute_instance_group_manager.go b/pkg/resource/google/google_compute_instance_group_manager.go deleted file mode 100644 index 9afe1319..00000000 --- a/pkg/resource/google/google_compute_instance_group_manager.go +++ /dev/null @@ -1,15 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeInstanceGroupManagerResourceType = "google_compute_instance_group_manager" - -func initComputeInstanceGroupManagerMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeInstanceGroupManagerResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - return attrs - }) -} diff --git a/pkg/resource/google/google_compute_network.go b/pkg/resource/google/google_compute_network.go index b679a9d8..2eaa58ce 100644 --- a/pkg/resource/google/google_compute_network.go +++ b/pkg/resource/google/google_compute_network.go @@ -1,20 +1,15 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeNetworkResourceType = "google_compute_network" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleComputeNetworkMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GoogleComputeNetworkResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleComputeNetworkResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) res.Attributes().SafeDelete([]string{"self_link"}) res.Attributes().SafeDelete([]string{"gateway_ipv4"}) res.Attributes().SafeDelete([]string{"delete_default_routes_on_create"}) }) - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeNetworkResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attributes().GetString("name"), - } - }) - resourceSchemaRepository.SetFlags(GoogleComputeNetworkResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/google/google_compute_router.go b/pkg/resource/google/google_compute_router.go deleted file mode 100644 index ea428a59..00000000 --- a/pkg/resource/google/google_compute_router.go +++ /dev/null @@ -1,15 +0,0 @@ -package google - -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeRouterResourceType = "google_compute_router" - -func initGoogleComputeRouterMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeRouterResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attrs.GetString("name"), - "region": *res.Attrs.GetString("region"), - "project": *res.Attrs.GetString("project"), - } - }) -} diff --git a/pkg/resource/google/google_compute_subnetwork.go b/pkg/resource/google/google_compute_subnetwork.go index 8497fad4..3632dbd0 100644 --- a/pkg/resource/google/google_compute_subnetwork.go +++ b/pkg/resource/google/google_compute_subnetwork.go @@ -1,27 +1,13 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleComputeSubnetworkResourceType = "google_compute_subnetwork" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleComputeSubnetworkMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleComputeSubnetworkResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": *res.Attributes().GetString("name"), - "region": *res.Attributes().GetString("region"), - } - }) - resourceSchemaRepository.SetNormalizeFunc(GoogleComputeSubnetworkResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleComputeSubnetworkResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"timeouts"}) res.Attributes().SafeDelete([]string{"self_link"}) }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleComputeSubnetworkResourceType, func(res *resource.Resource) map[string]string { - attrs := make(map[string]string) - - if v := res.Attributes().GetString("name"); v != nil && *v != "" { - attrs["Name"] = *v - } - return attrs - }) - resourceSchemaRepository.SetFlags(GoogleComputeSubnetworkResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/google/google_project_iam_member.go b/pkg/resource/google/google_project_iam_member.go index 23a686a6..10d33ff1 100644 --- a/pkg/resource/google/google_project_iam_member.go +++ b/pkg/resource/google/google_project_iam_member.go @@ -1,21 +1,13 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleProjectIamMemberResourceType = "google_project_iam_member" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleProjectIAMMemberMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GoogleProjectIamMemberResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleProjectIamMemberResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"force_destroy"}) res.Attributes().SafeDelete([]string{"etag"}) }) - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleProjectIamMemberResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "project": *res.Attrs.GetString("project"), - "role": *res.Attrs.GetString("role"), - "member": *res.Attrs.GetString("member"), - } - }) - resourceSchemaRepository.SetFlags(GoogleProjectIamMemberResourceType, resource.FlagDeepMode) - } diff --git a/pkg/resource/google/google_storage_bucket.go b/pkg/resource/google/google_storage_bucket.go index 265f5267..903aabfe 100644 --- a/pkg/resource/google/google_storage_bucket.go +++ b/pkg/resource/google/google_storage_bucket.go @@ -1,17 +1,12 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleStorageBucketResourceType = "google_storage_bucket" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleStorageBucketMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GoogleStorageBucketResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleStorageBucketResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"force_destroy"}) }) - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleStorageBucketResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "name": res.ResourceId(), - } - }) - resourceSchemaRepository.SetFlags(GoogleStorageBucketResourceType, resource.FlagDeepMode) } diff --git a/pkg/resource/google/google_storage_bucket_iam_member.go b/pkg/resource/google/google_storage_bucket_iam_member.go index 5ff5e28d..b1855db1 100644 --- a/pkg/resource/google/google_storage_bucket_iam_member.go +++ b/pkg/resource/google/google_storage_bucket_iam_member.go @@ -1,29 +1,13 @@ package google -import "github.com/snyk/driftctl/pkg/resource" - -const GoogleStorageBucketIamMemberResourceType = "google_storage_bucket_iam_member" +import ( + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/enumeration/resource/google" +) func initGoogleStorageBucketIamBMemberMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { - resourceSchemaRepository.SetNormalizeFunc(GoogleStorageBucketIamMemberResourceType, func(res *resource.Resource) { + resourceSchemaRepository.SetNormalizeFunc(google.GoogleStorageBucketIamMemberResourceType, func(res *resource.Resource) { res.Attributes().SafeDelete([]string{"force_destroy"}) res.Attributes().SafeDelete([]string{"etag"}) }) - resourceSchemaRepository.SetResolveReadAttributesFunc(GoogleStorageBucketIamMemberResourceType, func(res *resource.Resource) map[string]string { - return map[string]string{ - "bucket": *res.Attrs.GetString("bucket"), - "role": *res.Attrs.GetString("role"), - "member": *res.Attrs.GetString("member"), - } - }) - resourceSchemaRepository.SetHumanReadableAttributesFunc(GoogleStorageBucketIamMemberResourceType, func(res *resource.Resource) map[string]string { - attrs := map[string]string{ - "bucket": *res.Attrs.GetString("bucket"), - "role": *res.Attrs.GetString("role"), - "member": *res.Attrs.GetString("member"), - } - return attrs - }) - resourceSchemaRepository.SetFlags(GoogleStorageBucketIamMemberResourceType, resource.FlagDeepMode) - } diff --git a/pkg/resource/google/metadata_test.go b/pkg/resource/google/metadata_test.go deleted file mode 100644 index 2d1ebd6a..00000000 --- a/pkg/resource/google/metadata_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package google - -import ( - "testing" - - "github.com/snyk/driftctl/pkg/resource" - tf "github.com/snyk/driftctl/pkg/terraform" - testresource "github.com/snyk/driftctl/test/resource" - "github.com/stretchr/testify/assert" -) - -func TestGoogle_Metadata_Flags(t *testing.T) { - testcases := map[string][]resource.Flags{ - GoogleBigqueryDatasetResourceType: {}, - GoogleComputeFirewallResourceType: {resource.FlagDeepMode}, - GoogleComputeInstanceResourceType: {}, - GoogleComputeInstanceGroupResourceType: {resource.FlagDeepMode}, - GoogleComputeNetworkResourceType: {resource.FlagDeepMode}, - GoogleComputeRouterResourceType: {}, - GoogleDNSManagedZoneResourceType: {}, - GoogleProjectIamBindingResourceType: {}, - GoogleProjectIamMemberResourceType: {resource.FlagDeepMode}, - GoogleProjectIamPolicyResourceType: {}, - GoogleStorageBucketResourceType: {resource.FlagDeepMode}, - GoogleStorageBucketIamBindingResourceType: {}, - GoogleStorageBucketIamMemberResourceType: {resource.FlagDeepMode}, - GoogleStorageBucketIamPolicyResourceType: {}, - GoogleBigqueryTableResourceType: {}, - GoogleComputeDiskResourceType: {}, - GoogleBigTableInstanceResourceType: {}, - GoogleComputeGlobalAddressResourceType: {}, - GoogleCloudRunServiceResourceType: {}, - GoogleComputeNodeGroupResourceType: {}, - GoogleComputeForwardingRuleResourceType: {}, - GoogleComputeInstanceGroupManagerResourceType: {}, - GoogleComputeGlobalForwardingRuleResourceType: {}, - } - - schemaRepository := testresource.InitFakeSchemaRepository(tf.GOOGLE, "3.78.0") - InitResourcesMetadata(schemaRepository) - - for ty, flags := range testcases { - t.Run(ty, func(tt *testing.T) { - sch, exist := schemaRepository.GetSchema(ty) - assert.True(tt, exist) - - if len(flags) == 0 { - assert.Equal(tt, resource.Flags(0x0), sch.Flags, "should not have any flag") - return - } - - for _, flag := range flags { - assert.Truef(tt, sch.Flags.HasFlag(flag), "should have given flag %d", flag) - } - }) - } -} diff --git a/pkg/resource/google/metadatas.go b/pkg/resource/google/metadatas.go index 788d2f22..9c74aa5d 100644 --- a/pkg/resource/google/metadatas.go +++ b/pkg/resource/google/metadatas.go @@ -1,22 +1,13 @@ package google -import "github.com/snyk/driftctl/pkg/resource" +import "github.com/snyk/driftctl/enumeration/resource" func InitResourcesMetadata(resourceSchemaRepository resource.SchemaRepositoryInterface) { initGoogleStorageBucketMetadata(resourceSchemaRepository) initGoogleComputeFirewallMetadata(resourceSchemaRepository) - initGoogleComputeRouterMetadata(resourceSchemaRepository) initGoogleComputeNetworkMetadata(resourceSchemaRepository) initGoogleStorageBucketIamBMemberMetadata(resourceSchemaRepository) initGoogleComputeInstanceGroupMetadata(resourceSchemaRepository) - initGoogleBigqueryDatasetMetadata(resourceSchemaRepository) - initGoogleBigqueryTableMetadata(resourceSchemaRepository) initGoogleProjectIAMMemberMetadata(resourceSchemaRepository) - initGoogleComputeAddressMetadata(resourceSchemaRepository) - initGoogleComputeGlobalAddressMetadata(resourceSchemaRepository) initGoogleComputeSubnetworkMetadata(resourceSchemaRepository) - initGoogleComputeDiskMetadata(resourceSchemaRepository) - initGoogleComputeImageMetadata(resourceSchemaRepository) - initGoogleComputeHealthCheckMetadata(resourceSchemaRepository) - initComputeInstanceGroupManagerMetadata(resourceSchemaRepository) } diff --git a/pkg/resource/google/testdata/acc/google_bigquery_dataset/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_bigquery_dataset/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_bigquery_dataset/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_bigquery_dataset/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_bigquery_table/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_bigquery_table/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_bigquery_table/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_bigquery_table/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_bigtable_instance/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_bigtable_instance/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_bigtable_instance/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_bigtable_instance/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_bigtable_table/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_bigtable_table/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_bigtable_table/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_bigtable_table/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_cloudfunctions_function/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_cloudfunctions_function/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_cloudfunctions_function/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_cloudfunctions_function/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/google/testdata/acc/google_cloudrun_service/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_cloudrun_service/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_cloudrun_service/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_cloudrun_service/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_address/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_address/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_address/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_address/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_disk/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_disk/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_disk/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_disk/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_firewall/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_firewall/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_firewall/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_firewall/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_global_address/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_global_address/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_global_address/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_global_address/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_health_check/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_health_check/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_health_check/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_health_check/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_image/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_image/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_image/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_image/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_instance/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_instance/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_instance/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_instance/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_instance_group_manager/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_instance_group_manager/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_compute_instance_group_manager/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_instance_group_manager/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/google/testdata/acc/google_compute_network/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_network/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_compute_network/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_network/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/google/testdata/acc/google_compute_router/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_router/.terraform.lock.hcl index 9ba9d75b..805a4551 100644 --- a/pkg/resource/google/testdata/acc/google_compute_router/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_router/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", diff --git a/pkg/resource/google/testdata/acc/google_compute_subnetwork/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_compute_subnetwork/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_compute_subnetwork/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_compute_subnetwork/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/google/testdata/acc/google_dns_managed_zone/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_dns_managed_zone/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_dns_managed_zone/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_dns_managed_zone/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/google/testdata/acc/google_sql_database_instance/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_sql_database_instance/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_sql_database_instance/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_sql_database_instance/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/google/testdata/acc/google_storage_bucket/.terraform.lock.hcl b/pkg/resource/google/testdata/acc/google_storage_bucket/.terraform.lock.hcl index c14895f0..309981b6 100644 --- a/pkg/resource/google/testdata/acc/google_storage_bucket/.terraform.lock.hcl +++ b/pkg/resource/google/testdata/acc/google_storage_bucket/.terraform.lock.hcl @@ -5,6 +5,7 @@ provider "registry.terraform.io/hashicorp/google" { version = "3.78.0" constraints = "3.78.0" hashes = [ + "h1:Seut9gKb/KzzUMxa9Qo59LRWcgURfBWMarqNTRjxnXE=", "h1:iCyTW8BWdr6Bvd5B89wkxlrB8xLxqHvT1CPmGuKembU=", "zh:027971c4689b6130619827fe57ce260aaca060db3446817d3a92869dba7cc07f", "zh:0876dbecc0d441bf2479edd17fe9141d77274b5071ea5f68ac26a2994bff66f3", @@ -24,6 +25,7 @@ provider "registry.terraform.io/hashicorp/random" { version = "3.1.0" hashes = [ "h1:BZMEPucF+pbu9gsPk0G0BHx7YP04+tKdq2MrRDF1EDM=", + "h1:rKYu5ZUbXwrLG1w81k7H3nce/Ys6yAxXhWcbtk36HjY=", "zh:2bbb3339f0643b5daa07480ef4397bd23a79963cc364cdfbb4e86354cb7725bc", "zh:3cd456047805bf639fbf2c761b1848880ea703a054f76db51852008b11008626", "zh:4f251b0eda5bb5e3dc26ea4400dba200018213654b69b4a5f96abee815b4f5ff", diff --git a/pkg/resource/init_metadatas.go b/pkg/resource/init_metadatas.go new file mode 100644 index 00000000..f1973b61 --- /dev/null +++ b/pkg/resource/init_metadatas.go @@ -0,0 +1,29 @@ +package resource + +import ( + "github.com/pkg/errors" + "github.com/snyk/driftctl/enumeration/remote/common" + "github.com/snyk/driftctl/enumeration/resource" + "github.com/snyk/driftctl/pkg/resource/aws" + "github.com/snyk/driftctl/pkg/resource/azurerm" + "github.com/snyk/driftctl/pkg/resource/github" + "github.com/snyk/driftctl/pkg/resource/google" +) + +func InitMetadatas(remote string, + resourceSchemaRepository *resource.SchemaRepository) error { + switch remote { + case common.RemoteAWSTerraform: + aws.InitResourcesMetadata(resourceSchemaRepository) + case common.RemoteGithubTerraform: + github.InitResourcesMetadata(resourceSchemaRepository) + case common.RemoteGoogleTerraform: + google.InitResourcesMetadata(resourceSchemaRepository) + case common.RemoteAzureTerraform: + azurerm.InitResourcesMetadata(resourceSchemaRepository) + + default: + return errors.Errorf("unsupported remote '%s'", remote) + } + return nil +} diff --git a/pkg/resource/mock_IaCSupplier.go b/pkg/resource/mock_IaCSupplier.go index 30907b86..508b0126 100644 --- a/pkg/resource/mock_IaCSupplier.go +++ b/pkg/resource/mock_IaCSupplier.go @@ -2,7 +2,10 @@ package resource -import mock "github.com/stretchr/testify/mock" +import ( + "github.com/snyk/driftctl/enumeration/resource" + mock "github.com/stretchr/testify/mock" +) // MockIaCSupplier is an autogenerated mock type for the IaCSupplier type type MockIaCSupplier struct { @@ -10,15 +13,15 @@ type MockIaCSupplier struct { } // Resources provides a mock function with given fields: -func (_m *MockIaCSupplier) Resources() ([]*Resource, error) { +func (_m *MockIaCSupplier) Resources() ([]*resource.Resource, error) { ret := _m.Called() - var r0 []*Resource - if rf, ok := ret.Get(0).(func() []*Resource); ok { + var r0 []*resource.Resource + if rf, ok := ret.Get(0).(func() []*resource.Resource); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*Resource) + r0 = ret.Get(0).([]*resource.Resource) } } diff --git a/pkg/resource/mock_Supplier.go b/pkg/resource/mock_Supplier.go index 981d4f48..4367ce5e 100644 --- a/pkg/resource/mock_Supplier.go +++ b/pkg/resource/mock_Supplier.go @@ -3,6 +3,7 @@ package resource import mock "github.com/stretchr/testify/mock" +import "github.com/snyk/driftctl/enumeration/resource" // MockSupplier is an autogenerated mock type for the Supplier type type MockSupplier struct { @@ -10,15 +11,15 @@ type MockSupplier struct { } // Resources provides a mock function with given fields: -func (_m *MockSupplier) Resources() ([]*Resource, error) { +func (_m *MockSupplier) Resources() ([]*resource.Resource, error) { ret := _m.Called() - var r0 []*Resource - if rf, ok := ret.Get(0).(func() []*Resource); ok { + var r0 []*resource.Resource + if rf, ok := ret.Get(0).(func() []*resource.Resource); ok { r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*Resource) + r0 = ret.Get(0).([]*resource.Resource) } } diff --git a/pkg/resource/supplier.go b/pkg/resource/supplier.go index 070fc8da..aa7af767 100644 --- a/pkg/resource/supplier.go +++ b/pkg/resource/supplier.go @@ -1,17 +1,9 @@ package resource -// Supplier supply the list of resource.Resource, it's the main interface to retrieve remote resources -type Supplier interface { - Resources() ([]*Resource, error) -} +import "github.com/snyk/driftctl/enumeration/resource" // IaCSupplier supply the list of resource.Resource, it's the main interface to retrieve state resources type IaCSupplier interface { - Supplier + resource.Supplier SourceCount() uint } - -type StoppableSupplier interface { - Supplier - Stop() -} diff --git a/pkg/telemetry/telemetry_test.go b/pkg/telemetry/telemetry_test.go index d7149e61..c2e2fc13 100644 --- a/pkg/telemetry/telemetry_test.go +++ b/pkg/telemetry/telemetry_test.go @@ -8,9 +8,9 @@ import ( "testing" "github.com/jarcoal/httpmock" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/pkg/analyser" "github.com/snyk/driftctl/pkg/memstore" - "github.com/snyk/driftctl/pkg/resource" "github.com/snyk/driftctl/pkg/version" "github.com/snyk/driftctl/test/mocks" "github.com/spf13/viper" diff --git a/pkg/terraform/mock_ResourceFactory.go b/pkg/terraform/mock_ResourceFactory.go deleted file mode 100644 index b3e0d0f9..00000000 --- a/pkg/terraform/mock_ResourceFactory.go +++ /dev/null @@ -1,53 +0,0 @@ -// Code generated by mockery v2.3.0. DO NOT EDIT. - -package terraform - -import ( - "github.com/snyk/driftctl/pkg/resource" - mock "github.com/stretchr/testify/mock" - cty "github.com/zclconf/go-cty/cty" -) - -// MockResourceFactory is an autogenerated mock type for the ResourceFactory type -type MockResourceFactory struct { - mock.Mock -} - -// CreateAbstractResource provides a mock function with given fields: ty, id, data -func (_m *MockResourceFactory) CreateAbstractResource(ty string, id string, data map[string]interface{}) *resource.Resource { - ret := _m.Called(ty, id, data) - - var r0 *resource.Resource - if rf, ok := ret.Get(0).(func(string, string, map[string]interface{}) *resource.Resource); ok { - r0 = rf(ty, id, data) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*resource.Resource) - } - } - - return r0 -} - -// CreateResource provides a mock function with given fields: data, ty -func (_m *MockResourceFactory) CreateResource(data interface{}, ty string) (*cty.Value, error) { - ret := _m.Called(data, ty) - - var r0 *cty.Value - if rf, ok := ret.Get(0).(func(interface{}, string) *cty.Value); ok { - r0 = rf(data, ty) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*cty.Value) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(interface{}, string) error); ok { - r1 = rf(data, ty) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/pkg/terraform/parallel_resource_reader.go b/pkg/terraform/parallel_resource_reader.go deleted file mode 100644 index 4ee1cc17..00000000 --- a/pkg/terraform/parallel_resource_reader.go +++ /dev/null @@ -1,43 +0,0 @@ -package terraform - -import ( - "github.com/snyk/driftctl/pkg/parallel" - - "github.com/zclconf/go-cty/cty" -) - -type ParallelResourceReader struct { - runner *parallel.ParallelRunner -} - -func NewParallelResourceReader(runner *parallel.ParallelRunner) *ParallelResourceReader { - return &ParallelResourceReader{ - runner: runner, - } -} - -func (p *ParallelResourceReader) Wait() ([]cty.Value, error) { - results := make([]cty.Value, 0) -Loop: - for { - select { - case res, ok := <-p.runner.Read(): - if !ok { - break Loop - } - ctyVal := res.(cty.Value) - if !ctyVal.IsNull() { - results = append(results, ctyVal) - } - case <-p.runner.DoneChan(): - break Loop - } - } - return results, p.runner.Err() -} - -func (p *ParallelResourceReader) Run(runnable func() (cty.Value, error)) { - p.runner.Run(func() (interface{}, error) { - return runnable() - }) -} diff --git a/pkg/terraform/parallel_resource_reader_test.go b/pkg/terraform/parallel_resource_reader_test.go deleted file mode 100644 index 134b7846..00000000 --- a/pkg/terraform/parallel_resource_reader_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package terraform - -import ( - "context" - "errors" - "strings" - "testing" - - "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/parallel" - - "github.com/stretchr/testify/assert" - - "github.com/zclconf/go-cty/cty" -) - -func TestParallelResourceReader_Wait(t *testing.T) { - assert := assert.New(t) - tests := []struct { - name string - execs []func() (cty.Value, error) - want []cty.Value - wantErr bool - }{ - { - name: "Working // read resource", - execs: []func() (cty.Value, error){ - func() (cty.Value, error) { - return cty.BoolVal(true), nil - }, - func() (cty.Value, error) { - return cty.StringVal("test"), nil - }, - }, - want: []cty.Value{cty.BoolVal(true), cty.StringVal("test")}, - wantErr: false, - }, - - { - name: "failing // read resource", - execs: []func() (cty.Value, error){ - func() (cty.Value, error) { - return cty.BoolVal(true), nil - }, - func() (cty.Value, error) { - return cty.NilVal, errors.New("error") - }, - func() (cty.Value, error) { - return cty.StringVal("test"), nil - }, - }, - want: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := NewParallelResourceReader(parallel.NewParallelRunner(context.TODO(), 10)) - - for _, fun := range tt.execs { - p.Run(fun) - } - - got, err := p.Wait() - assert.Equal(tt.wantErr, err != nil) - if tt.want != nil { - changelog, err := diff.Diff(got, tt.want) - if err != nil { - panic(err) - } - if len(changelog) > 0 { - for _, change := range changelog { - t.Errorf("%s got = %v, want %v", strings.Join(change.Path, "."), change.From, change.To) - } - } - } - }) - } -} diff --git a/pkg/terraform/provider_downloader.go b/pkg/terraform/provider_downloader.go deleted file mode 100644 index d9463ad4..00000000 --- a/pkg/terraform/provider_downloader.go +++ /dev/null @@ -1,80 +0,0 @@ -package terraform - -import ( - "context" - "io/ioutil" - "net/http" - "os" - - "github.com/hashicorp/go-getter" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - error2 "github.com/snyk/driftctl/pkg/terraform/error" -) - -type ProviderDownloaderInterface interface { - Download(url, path string) error -} - -type ProviderDownloader struct { - httpclient *http.Client - unzip getter.ZipDecompressor - context context.Context -} - -func NewProviderDownloader() *ProviderDownloader { - return &ProviderDownloader{ - httpclient: http.DefaultClient, - unzip: getter.ZipDecompressor{}, - context: context.Background(), - } -} - -func (p *ProviderDownloader) Download(url, path string) error { - logrus.WithFields(logrus.Fields{ - "url": url, - "path": path, - }).Debug("Downloading provider") - - req, err := http.NewRequestWithContext(p.context, "GET", url, nil) - if err != nil { - return err - } - resp, err := p.httpclient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusForbidden { - return error2.ProviderNotFoundError{} - } - if resp.StatusCode != http.StatusOK { - return errors.Errorf("unsuccessful request to %s: %s", url, resp.Status) - } - f, err := ioutil.TempFile("", "terraform-provider") - if err != nil { - return errors.Errorf("failed to open temporary file to download from %s", url) - } - defer f.Close() - defer os.Remove(f.Name()) - n, err := getter.Copy(p.context, f, resp.Body) - if err == nil && n < resp.ContentLength { - err = errors.Errorf( - "incorrect response size: expected %d bytes, but got %d bytes", - resp.ContentLength, - n, - ) - } - if err != nil { - return err - } - logrus.WithFields(logrus.Fields{ - "src": f.Name(), - "dst": path, - }).Debug("Decompressing archive") - err = p.unzip.Decompress(path, f.Name(), true, 0) - if err != nil { - return err - } - return nil -} diff --git a/pkg/terraform/provider_downloader_test.go b/pkg/terraform/provider_downloader_test.go deleted file mode 100644 index a77734f3..00000000 --- a/pkg/terraform/provider_downloader_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package terraform - -import ( - "fmt" - "io/ioutil" - "net/http" - "path" - "testing" - - "github.com/aws/aws-sdk-go/aws" - terraformError "github.com/snyk/driftctl/pkg/terraform/error" - - "github.com/stretchr/testify/assert" - - "github.com/jarcoal/httpmock" -) - -func TestProviderDownloader_Download(t *testing.T) { - httpmock.Activate() - defer httpmock.DeactivateAndReset() - downloader := NewProviderDownloader() - url := "https://example.com/terraform-provider-aws_3.19.0_linux_amd64.zip" - - cases := []struct { - name string - httpStatus *int - testFile *string - responder httpmock.Responder - assert func(assert *assert.Assertions, tmpDir string, err error) - }{ - { - name: "TestBadResponse(404)", - responder: httpmock.NewBytesResponder(http.StatusNotFound, []byte{}), - assert: func(assert *assert.Assertions, tmpDir string, err error) { - assert.Equal( - fmt.Sprintf("unsuccessful request to %s: 404", url), - err.Error(), - ) - }, - }, - { - name: "TestProviderNotFound(403)", - responder: httpmock.NewBytesResponder(http.StatusForbidden, []byte{}), - assert: func(assert *assert.Assertions, tmpDir string, err error) { - assert.IsType( - terraformError.ProviderNotFoundError{}, - err, - ) - }, - }, - { - name: "TestHttpError", - responder: httpmock.NewErrorResponder(fmt.Errorf("test error")), - assert: func(assert *assert.Assertions, tmpDir string, err error) { - assert.Contains(err.Error(), "test error") - }, - }, - { - name: "TestInvalidZip", - testFile: aws.String("invalid.zip"), - assert: func(assert *assert.Assertions, tmpDir string, err error) { - assert.NotNil(err) - infos, err := ioutil.ReadDir(tmpDir) - assert.Nil(err) - assert.Len(infos, 0) - }, - }, - { - name: "TestValidZip", - testFile: aws.String("terraform-provider-aws_3.5.0_linux_amd64.zip"), - assert: func(assert *assert.Assertions, tmpDir string, err error) { - assert.Nil(err) - file, err := ioutil.ReadFile(path.Join(tmpDir, "terraform-provider-aws_v3.5.0_x5")) - assert.Nil(err) - assert.Equal([]byte{0x74, 0x65, 0x73, 0x74, 0xa}, file) - }, - }, - } - - for _, c := range cases { - - t.Run(c.name, func(tt *testing.T) { - tmpDir := tt.TempDir() - - httpmock.Reset() - assert := assert.New(tt) - - if c.httpStatus == nil { - c.httpStatus = aws.Int(http.StatusOK) - } - - if c.responder != nil { - httpmock.RegisterResponder("GET", url, c.responder) - } else { - if c.testFile != nil { - body, err := ioutil.ReadFile("./testdata/" + *c.testFile) - if err != nil { - tt.Error(err) - } - httpmock.RegisterResponder("GET", url, httpmock.NewBytesResponder(*c.httpStatus, body)) - } - } - - err := downloader.Download(url, tmpDir) - - c.assert(assert, tmpDir, err) - }) - - } -} diff --git a/pkg/terraform/provider_installer.go b/pkg/terraform/provider_installer.go deleted file mode 100644 index bbe6daa5..00000000 --- a/pkg/terraform/provider_installer.go +++ /dev/null @@ -1,97 +0,0 @@ -package terraform - -import ( - "fmt" - "io/fs" - "os" - "path" - "path/filepath" - "runtime" - "strings" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - error2 "github.com/snyk/driftctl/pkg/terraform/error" - - "github.com/snyk/driftctl/pkg/output" -) - -type HomeDirInterface interface { - Dir() (string, error) -} - -type ProviderInstaller struct { - downloader ProviderDownloaderInterface - config ProviderConfig - homeDir string -} - -func NewProviderInstaller(config ProviderConfig) (*ProviderInstaller, error) { - return &ProviderInstaller{ - NewProviderDownloader(), - config, - config.ConfigDir, - }, nil -} - -func (p *ProviderInstaller) Install() (string, error) { - providerDir := p.getProviderDirectory() - providerPath := p.getBinaryPath() - - info, err := os.Stat(providerPath) - - if err != nil && os.IsNotExist(err) { - logrus.WithFields(logrus.Fields{ - "path": providerPath, - }).Debug("provider not found, downloading ...") - output.Printf("Downloading terraform provider: %s\n", p.config.Key) - err := p.downloader.Download( - p.config.GetDownloadUrl(), - providerDir, - ) - if err != nil { - if notFoundErr, ok := err.(error2.ProviderNotFoundError); ok { - notFoundErr.Version = p.config.Version - return "", notFoundErr - } - return "", err - } - logrus.Debug("Download successful") - } - - if info != nil && info.IsDir() { - return "", errors.Errorf( - "found directory instead of provider binary in %s", - providerPath, - ) - } - - if info != nil { - logrus.WithFields(logrus.Fields{ - "path": providerPath, - }).Debug("Found existing provider") - } - - return p.getBinaryPath(), nil -} - -func (p ProviderInstaller) getProviderDirectory() string { - return path.Join(p.homeDir, fmt.Sprintf(".driftctl/plugins/%s_%s/", runtime.GOOS, runtime.GOARCH)) -} - -// Handle postfixes in binary names -func (p *ProviderInstaller) getBinaryPath() string { - providerDir := p.getProviderDirectory() - binaryName := p.config.GetBinaryName() - _, err := os.Stat(path.Join(providerDir, binaryName)) - if err != nil && os.IsNotExist(err) { - _ = filepath.WalkDir(providerDir, func(filePath string, d fs.DirEntry, err error) error { - if d != nil && strings.HasPrefix(d.Name(), p.config.GetBinaryName()) { - binaryName = d.Name() - } - return nil - }) - } - - return path.Join(providerDir, binaryName) -} diff --git a/pkg/terraform/provider_installer_test.go b/pkg/terraform/provider_installer_test.go deleted file mode 100644 index d861b32b..00000000 --- a/pkg/terraform/provider_installer_test.go +++ /dev/null @@ -1,206 +0,0 @@ -package terraform - -import ( - "fmt" - "os" - "path" - "runtime" - "testing" - - "github.com/snyk/driftctl/mocks" - terraformError "github.com/snyk/driftctl/pkg/terraform/error" - "github.com/stretchr/testify/mock" - - "github.com/stretchr/testify/assert" -) - -func TestProviderInstallerInstallDoesNotExist(t *testing.T) { - - assert := assert.New(t) - fakeTmpHome := t.TempDir() - - expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) - - config := ProviderConfig{ - Key: "aws", - Version: "3.19.0", - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil) - - installer := ProviderInstaller{ - downloader: &mockDownloader, - config: config, - homeDir: fakeTmpHome, - } - - providerPath, err := installer.Install() - mockDownloader.AssertExpectations(t) - - assert.Nil(err) - assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) - -} - -func TestProviderInstallerInstallAlreadyExist(t *testing.T) { - - assert := assert.New(t) - fakeTmpHome := t.TempDir() - expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) - err := os.MkdirAll(path.Join(fakeTmpHome, expectedSubFolder), 0755) - if err != nil { - t.Error(err) - } - - config := ProviderConfig{ - Key: "aws", - Version: "3.19.0", - } - - _, err = os.Create(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName())) - if err != nil { - t.Error(err) - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - - installer := ProviderInstaller{ - downloader: &mockDownloader, - config: config, - homeDir: fakeTmpHome, - } - - providerPath, err := installer.Install() - mockDownloader.AssertExpectations(t) - - assert.Nil(err) - assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) - -} - -func TestProviderInstallerInstallAlreadyExistButIsDirectory(t *testing.T) { - - assert := assert.New(t) - fakeTmpHome := t.TempDir() - expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) - - config := ProviderConfig{ - Key: "aws", - Version: "3.19.0", - } - - invalidDirPath := path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()) - err := os.MkdirAll(invalidDirPath, 0755) - if err != nil { - t.Error(err) - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - - installer := ProviderInstaller{ - downloader: &mockDownloader, - config: config, - homeDir: fakeTmpHome, - } - - providerPath, err := installer.Install() - mockDownloader.AssertExpectations(t) - - assert.Empty(providerPath) - assert.NotNil(err) - assert.Equal( - fmt.Sprintf( - "found directory instead of provider binary in %s", - invalidDirPath, - ), - err.Error(), - ) - -} - -// Ensure that if a provider exists with a postfix (_x5) we properly detect it -func TestProviderInstallerInstallPostfixIsHandler(t *testing.T) { - - assert := assert.New(t) - fakeTmpHome := t.TempDir() - expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) - err := os.MkdirAll(path.Join(fakeTmpHome, expectedSubFolder), 0755) - if err != nil { - t.Error(err) - } - - config := ProviderConfig{ - Key: "aws", - Version: "3.19.0", - } - - _, err = os.Create(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()+"_x5")) - if err != nil { - t.Fatal(err) - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - - installer := ProviderInstaller{ - downloader: &mockDownloader, - config: config, - homeDir: fakeTmpHome, - } - - providerPath, err := installer.Install() - mockDownloader.AssertExpectations(t) - - assert.Nil(err) - assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()+"_x5"), providerPath) - -} - -func TestProviderInstallerVersionDoesNotExist(t *testing.T) { - - assert := assert.New(t) - - config := ProviderConfig{ - Key: "aws", - Version: "666.666.666", - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - mockDownloader.On("Download", mock.Anything, mock.Anything).Return(terraformError.ProviderNotFoundError{}) - - installer := ProviderInstaller{ - downloader: &mockDownloader, - config: config, - } - - _, err := installer.Install() - - assert.Equal("Provider version 666.666.666 does not exist", err.Error()) -} - -func TestProviderInstallerWithConfigDirectory(t *testing.T) { - - assert := assert.New(t) - fakeTmpHome := t.TempDir() - - expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) - - config := ProviderConfig{ - Key: "aws", - Version: "3.19.0", - ConfigDir: fakeTmpHome, - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil) - - installer, _ := NewProviderInstaller(config) - installer.downloader = &mockDownloader - - providerPath, err := installer.Install() - mockDownloader.AssertExpectations(t) - - assert.Nil(err) - assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) - -} diff --git a/pkg/terraform/resource_factory.go b/pkg/terraform/resource_factory.go deleted file mode 100644 index 33670989..00000000 --- a/pkg/terraform/resource_factory.go +++ /dev/null @@ -1,35 +0,0 @@ -package terraform - -import ( - "github.com/snyk/driftctl/pkg/resource" -) - -type TerraformResourceFactory struct { - resourceSchemaRepository resource.SchemaRepositoryInterface -} - -func NewTerraformResourceFactory(resourceSchemaRepository resource.SchemaRepositoryInterface) *TerraformResourceFactory { - return &TerraformResourceFactory{ - resourceSchemaRepository: resourceSchemaRepository, - } -} - -func (r *TerraformResourceFactory) CreateAbstractResource(ty, id string, data map[string]interface{}) *resource.Resource { - attributes := resource.Attributes(data) - attributes.SanitizeDefaults() - - schema, _ := r.resourceSchemaRepository.GetSchema(ty) - res := resource.Resource{ - Id: id, - Type: ty, - Attrs: &attributes, - Sch: schema, - } - - schema, exist := r.resourceSchemaRepository.(*resource.SchemaRepository).GetSchema(ty) - if exist && schema.NormalizeFunc != nil { - schema.NormalizeFunc(&res) - } - - return &res -} diff --git a/pkg/terraform/resource_reader.go b/pkg/terraform/resource_reader.go deleted file mode 100644 index 8600848c..00000000 --- a/pkg/terraform/resource_reader.go +++ /dev/null @@ -1,17 +0,0 @@ -package terraform - -import ( - "github.com/snyk/driftctl/pkg/resource" - - "github.com/zclconf/go-cty/cty" -) - -type ResourceReader interface { - ReadResource(args ReadResourceArgs) (*cty.Value, error) -} - -type ReadResourceArgs struct { - Ty resource.ResourceType - ID string - Attributes map[string]string -} diff --git a/test/cty_test_diff.go b/test/cty_test_diff.go index debb45e5..8d292dde 100644 --- a/test/cty_test_diff.go +++ b/test/cty_test_diff.go @@ -5,13 +5,13 @@ import ( "strings" "testing" + "github.com/snyk/driftctl/enumeration/terraform" + "github.com/snyk/driftctl/test/goldenfile" "github.com/zclconf/go-cty/cty/json" - "github.com/snyk/driftctl/pkg/resource" - - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" diff --git a/test/mocks/MockGoldenTerraformProvider.go b/test/mocks/MockGoldenTerraformProvider.go index 68d6c050..f570a2d3 100644 --- a/test/mocks/MockGoldenTerraformProvider.go +++ b/test/mocks/MockGoldenTerraformProvider.go @@ -4,23 +4,23 @@ import ( gojson "encoding/json" "errors" "fmt" + terraform2 "github.com/snyk/driftctl/enumeration/terraform" "sort" "github.com/snyk/driftctl/test/goldenfile" "github.com/hashicorp/terraform/providers" - "github.com/snyk/driftctl/pkg/terraform" "github.com/zclconf/go-cty/cty" ctyjson "github.com/zclconf/go-cty/cty/json" ) type MockedGoldenTFProvider struct { name string - realProvider terraform.TerraformProvider + realProvider terraform2.TerraformProvider update bool } -func NewMockedGoldenTFProvider(name string, realProvider terraform.TerraformProvider, update bool) *MockedGoldenTFProvider { +func NewMockedGoldenTFProvider(name string, realProvider terraform2.TerraformProvider, update bool) *MockedGoldenTFProvider { return &MockedGoldenTFProvider{name: name, realProvider: realProvider, update: update} } @@ -33,7 +33,7 @@ func (m *MockedGoldenTFProvider) Schema() map[string]providers.Schema { return m.readSchema() } -func (m *MockedGoldenTFProvider) ReadResource(args terraform.ReadResourceArgs) (*cty.Value, error) { +func (m *MockedGoldenTFProvider) ReadResource(args terraform2.ReadResourceArgs) (*cty.Value, error) { if m.update { readResource, err := m.realProvider.ReadResource(args) m.writeReadResource(args, readResource, err) @@ -60,7 +60,7 @@ func (m *MockedGoldenTFProvider) readSchema() map[string]providers.Schema { return schema } -func (m *MockedGoldenTFProvider) writeReadResource(args terraform.ReadResourceArgs, readResource *cty.Value, err error) { +func (m *MockedGoldenTFProvider) writeReadResource(args terraform2.ReadResourceArgs, readResource *cty.Value, err error) { var readRes = ReadResource{ Value: readResource, Err: err, @@ -74,7 +74,7 @@ func (m *MockedGoldenTFProvider) writeReadResource(args terraform.ReadResourceAr goldenfile.WriteFile(m.name, marshalled, fileName) } -func (m *MockedGoldenTFProvider) readReadResource(args terraform.ReadResourceArgs) (*cty.Value, error) { +func (m *MockedGoldenTFProvider) readReadResource(args terraform2.ReadResourceArgs) (*cty.Value, error) { fileName := getFileName(args) // TODO I'm putting this here for compatibility reason... if !goldenfile.FileExists(m.name, fileName) { @@ -146,13 +146,13 @@ func (m *ReadResource) MarshalJSON() ([]byte, error) { return gojson.Marshal(unm) } -func getFileName(args terraform.ReadResourceArgs) string { +func getFileName(args terraform2.ReadResourceArgs) string { suffix := getFileNameSuffix(args) fileName := fmt.Sprintf("%s-%s%s.res.golden.json", args.Ty, args.ID, suffix) return fileName } -func getFileNameSuffix(args terraform.ReadResourceArgs) string { +func getFileNameSuffix(args terraform2.ReadResourceArgs) string { suffix := "" keys := make([]string, 0, len(args.Attributes)) for k := range args.Attributes { diff --git a/test/remote/scanner.go b/test/remote/scanner.go index 65183ad9..63999461 100644 --- a/test/remote/scanner.go +++ b/test/remote/scanner.go @@ -1,7 +1,7 @@ package remote import ( - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" ) type SortableScanner struct { diff --git a/test/resource/resource.go b/test/resource/resource.go index f9a32ee3..f720d392 100644 --- a/test/resource/resource.go +++ b/test/resource/resource.go @@ -2,7 +2,7 @@ package resource import ( "github.com/hashicorp/terraform/providers" - "github.com/snyk/driftctl/pkg/resource" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/test/schemas" ) diff --git a/test/terraform/fake_terraform_provider.go b/test/terraform/fake_terraform_provider.go index 19bc4c5f..6264448d 100644 --- a/test/terraform/fake_terraform_provider.go +++ b/test/terraform/fake_terraform_provider.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/terraform/providers" "github.com/pkg/errors" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/terraform" "github.com/snyk/driftctl/test/goldenfile" "github.com/snyk/driftctl/test/mocks" "github.com/snyk/driftctl/test/schemas" diff --git a/test/terraform/provider.go b/test/terraform/provider.go index 35892d71..9aa5cf6e 100644 --- a/test/terraform/provider.go +++ b/test/terraform/provider.go @@ -1,14 +1,14 @@ package terraform import ( + "github.com/snyk/driftctl/enumeration/remote/aws" + "github.com/snyk/driftctl/enumeration/remote/azurerm" + "github.com/snyk/driftctl/enumeration/remote/github" + "github.com/snyk/driftctl/enumeration/remote/google" + "github.com/snyk/driftctl/enumeration/terraform" "os" "github.com/snyk/driftctl/pkg/output" - "github.com/snyk/driftctl/pkg/remote/aws" - "github.com/snyk/driftctl/pkg/remote/azurerm" - "github.com/snyk/driftctl/pkg/remote/github" - "github.com/snyk/driftctl/pkg/remote/google" - "github.com/snyk/driftctl/pkg/terraform" ) func InitTestAwsProvider(providerLibrary *terraform.ProviderLibrary, version string) (*aws.AWSTerraformProvider, error) { diff --git a/test/terraform/schemas_test.go b/test/terraform/schemas_test.go index 30297bb2..398f7f2a 100644 --- a/test/terraform/schemas_test.go +++ b/test/terraform/schemas_test.go @@ -4,7 +4,8 @@ import ( "os" "testing" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/terraform" + "github.com/snyk/driftctl/test/schemas" ) diff --git a/test/test_diff.go b/test/test_diff.go index 4cf5b08c..9e424f50 100644 --- a/test/test_diff.go +++ b/test/test_diff.go @@ -4,10 +4,11 @@ import ( "strings" "testing" + "github.com/snyk/driftctl/enumeration/terraform" + "github.com/aws/aws-sdk-go/aws/awsutil" "github.com/r3labs/diff/v2" - "github.com/snyk/driftctl/pkg/resource" - "github.com/snyk/driftctl/pkg/terraform" + "github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/test/goldenfile" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/gocty"