chore: extract enumeration to it's own submodule

main
Martin Guibert 2022-06-28 09:23:29 +02:00 committed by Elie CHARRA
parent 516609f12d
commit dc58f94ce9
No known key found for this signature in database
GPG Key ID: 399AF69092C727B6
1624 changed files with 51857 additions and 50813 deletions

View File

@ -0,0 +1,75 @@
package alerter
import (
"fmt"
"github.com/snyk/driftctl/enumeration/resource"
)
type AlerterInterface interface {
SendAlert(key string, alert Alert)
}
type Alerter struct {
alerts Alerts
alertsCh chan Alerts
doneCh chan bool
}
func NewAlerter() *Alerter {
var alerter = &Alerter{
alerts: make(Alerts),
alertsCh: make(chan Alerts),
doneCh: make(chan bool),
}
go alerter.run()
return alerter
}
func (a *Alerter) run() {
defer func() { a.doneCh <- true }()
for alert := range a.alertsCh {
for k, v := range alert {
if val, ok := a.alerts[k]; ok {
a.alerts[k] = append(val, v...)
} else {
a.alerts[k] = v
}
}
}
}
func (a *Alerter) SetAlerts(alerts Alerts) {
a.alerts = alerts
}
func (a *Alerter) Retrieve() Alerts {
close(a.alertsCh)
<-a.doneCh
return a.alerts
}
func (a *Alerter) SendAlert(key string, alert Alert) {
a.alertsCh <- Alerts{
key: []Alert{alert},
}
}
func (a *Alerter) IsResourceIgnored(res *resource.Resource) bool {
alert, alertExists := a.alerts[fmt.Sprintf("%s.%s", res.ResourceType(), res.ResourceId())]
wildcardAlert, wildcardAlertExists := a.alerts[res.ResourceType()]
shouldIgnoreAlert := a.shouldBeIgnored(alert)
shouldIgnoreWildcardAlert := a.shouldBeIgnored(wildcardAlert)
return (alertExists && shouldIgnoreAlert) || (wildcardAlertExists && shouldIgnoreWildcardAlert)
}
func (a *Alerter) shouldBeIgnored(alert []Alert) bool {
for _, a := range alert {
if a.ShouldIgnoreResource() {
return true
}
}
return false
}

View File

@ -0,0 +1,161 @@
package alerter
import (
"reflect"
"testing"
"github.com/snyk/driftctl/enumeration/resource"
)
func TestAlerter_Alert(t *testing.T) {
cases := []struct {
name string
alerts Alerts
expected Alerts
}{
{
name: "TestNoAlerts",
alerts: nil,
expected: Alerts{},
},
{
name: "TestWithSingleAlert",
alerts: Alerts{
"fakeres.foobar": []Alert{
&FakeAlert{"This is an alert", false},
},
},
expected: Alerts{
"fakeres.foobar": []Alert{
&FakeAlert{"This is an alert", false},
},
},
},
{
name: "TestWithMultipleAlerts",
alerts: Alerts{
"fakeres.foobar": []Alert{
&FakeAlert{"This is an alert", false},
&FakeAlert{"This is a second alert", true},
},
"fakeres.barfoo": []Alert{
&FakeAlert{"This is a third alert", true},
},
},
expected: Alerts{
"fakeres.foobar": []Alert{
&FakeAlert{"This is an alert", false},
&FakeAlert{"This is a second alert", true},
},
"fakeres.barfoo": []Alert{
&FakeAlert{"This is a third alert", true},
},
},
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
alerter := NewAlerter()
for k, v := range c.alerts {
for _, a := range v {
alerter.SendAlert(k, a)
}
}
if eq := reflect.DeepEqual(alerter.Retrieve(), c.expected); !eq {
t.Errorf("Got %+v, expected %+v", alerter.Retrieve(), c.expected)
}
})
}
}
func TestAlerter_IgnoreResources(t *testing.T) {
cases := []struct {
name string
alerts Alerts
resource *resource.Resource
expected bool
}{
{
name: "TestNoAlerts",
alerts: Alerts{},
resource: &resource.Resource{
Type: "fakeres",
Id: "foobar",
},
expected: false,
},
{
name: "TestShouldNotBeIgnoredWithAlerts",
alerts: Alerts{
"fakeres": {
&FakeAlert{"Should not be ignored", false},
},
"fakeres.foobar": {
&FakeAlert{"Should not be ignored", false},
},
"fakeres.barfoo": {
&FakeAlert{"Should not be ignored", false},
},
"other.resource": {
&FakeAlert{"Should not be ignored", false},
},
},
resource: &resource.Resource{
Type: "fakeres",
Id: "foobar",
},
expected: false,
},
{
name: "TestShouldBeIgnoredWithAlertsOnWildcard",
alerts: Alerts{
"fakeres": {
&FakeAlert{"Should be ignored", true},
},
"other.foobaz": {
&FakeAlert{"Should be ignored", true},
},
"other.resource": {
&FakeAlert{"Should not be ignored", false},
},
},
resource: &resource.Resource{
Type: "fakeres",
Id: "foobar",
},
expected: true,
},
{
name: "TestShouldBeIgnoredWithAlertsOnResource",
alerts: Alerts{
"fakeres": {
&FakeAlert{"Should be ignored", true},
},
"other.foobaz": {
&FakeAlert{"Should be ignored", true},
},
"other.resource": {
&FakeAlert{"Should not be ignored", false},
},
},
resource: &resource.Resource{
Type: "other",
Id: "foobaz",
},
expected: true,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
alerter := NewAlerter()
alerter.SetAlerts(c.alerts)
if got := alerter.IsResourceIgnored(c.resource); got != c.expected {
t.Errorf("Got %+v, expected %+v", got, c.expected)
}
})
}
}

9
enumeration/filter.go Normal file
View File

@ -0,0 +1,9 @@
package enumeration
import "github.com/snyk/driftctl/enumeration/resource"
type Filter interface {
IsTypeIgnored(ty resource.ResourceType) bool
IsResourceIgnored(res *resource.Resource) bool
IsFieldIgnored(res *resource.Resource, path []string) bool
}

View File

@ -0,0 +1,55 @@
// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
package enumeration
import (
resource "github.com/snyk/driftctl/enumeration/resource"
mock "github.com/stretchr/testify/mock"
)
// MockFilter is an autogenerated mock type for the Filter type
type MockFilter struct {
mock.Mock
}
// IsFieldIgnored provides a mock function with given fields: res, path
func (_m *MockFilter) IsFieldIgnored(res *resource.Resource, path []string) bool {
ret := _m.Called(res, path)
var r0 bool
if rf, ok := ret.Get(0).(func(*resource.Resource, []string) bool); ok {
r0 = rf(res, path)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// IsResourceIgnored provides a mock function with given fields: res
func (_m *MockFilter) IsResourceIgnored(res *resource.Resource) bool {
ret := _m.Called(res)
var r0 bool
if rf, ok := ret.Get(0).(func(*resource.Resource) bool); ok {
r0 = rf(res)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// IsTypeIgnored provides a mock function with given fields: ty
func (_m *MockFilter) IsTypeIgnored(ty resource.ResourceType) bool {
ret := _m.Called(ty)
var r0 bool
if rf, ok := ret.Get(0).(func(resource.ResourceType) bool); ok {
r0 = rf(ty)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}

5
enumeration/progress.go Normal file
View File

@ -0,0 +1,5 @@
package enumeration
type ProgressCounter interface {
Inc()
}

View File

@ -0,0 +1,97 @@
package alerts
import (
"fmt"
"github.com/snyk/driftctl/enumeration/alerter"
"github.com/snyk/driftctl/enumeration/remote/common"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/sirupsen/logrus"
)
type ScanningPhase int
const (
EnumerationPhase ScanningPhase = iota
DetailsFetchingPhase
)
type RemoteAccessDeniedAlert struct {
message string
provider string
scanningPhase ScanningPhase
}
func NewRemoteAccessDeniedAlert(provider string, scanErr *remoteerror.ResourceScanningError, scanningPhase ScanningPhase) *RemoteAccessDeniedAlert {
var message string
switch scanningPhase {
case EnumerationPhase:
message = fmt.Sprintf(
"Ignoring %s from drift calculation: Listing %s is forbidden: %s",
scanErr.Resource(),
scanErr.ListedTypeError(),
scanErr.RootCause().Error(),
)
case DetailsFetchingPhase:
message = fmt.Sprintf(
"Ignoring %s from drift calculation: Reading details of %s is forbidden: %s",
scanErr.Resource(),
scanErr.ListedTypeError(),
scanErr.RootCause().Error(),
)
default:
message = fmt.Sprintf(
"Ignoring %s from drift calculation: %s",
scanErr.Resource(),
scanErr.RootCause().Error(),
)
}
return &RemoteAccessDeniedAlert{message, provider, scanningPhase}
}
func (e *RemoteAccessDeniedAlert) Message() string {
return e.message
}
func (e *RemoteAccessDeniedAlert) ShouldIgnoreResource() bool {
return true
}
func (e *RemoteAccessDeniedAlert) GetProviderMessage() string {
var message string
if e.scanningPhase == DetailsFetchingPhase {
message = "It seems that we got access denied exceptions while reading details of resources.\n"
}
if e.scanningPhase == EnumerationPhase {
message = "It seems that we got access denied exceptions while listing resources.\n"
}
switch e.provider {
case common.RemoteGithubTerraform:
message += "Please be sure that your Github token has the right permissions, check the last up-to-date documentation there: https://docs.driftctl.com/github/policy"
case common.RemoteAWSTerraform:
message += "The latest minimal read-only IAM policy for driftctl is always available here, please update yours: https://docs.driftctl.com/aws/policy"
case common.RemoteGoogleTerraform:
message += "Please ensure that you have configured the required roles, please check our documentation at https://docs.driftctl.com/google/policy"
default:
return ""
}
return message
}
func sendRemoteAccessDeniedAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError, p ScanningPhase) {
logrus.WithFields(logrus.Fields{
"resource": listError.Resource(),
"listed_type": listError.ListedTypeError(),
}).Debugf("Got an access denied error: %+v", listError.Error())
alerter.SendAlert(listError.Resource(), NewRemoteAccessDeniedAlert(provider, listError, p))
}
func SendEnumerationAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError) {
sendRemoteAccessDeniedAlert(provider, alerter, listError, EnumerationPhase)
}
func SendDetailsFetchingAlert(provider string, alerter alerter.AlerterInterface, listError *remoteerror.ResourceScanningError) {
sendRemoteAccessDeniedAlert(provider, alerter, listError, DetailsFetchingPhase)
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayAccountEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayAccountEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayAccountEnumerator {
return &ApiGatewayAccountEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayAccountEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayAccountResourceType
}
func (e *ApiGatewayAccountEnumerator) Enumerate() ([]*resource.Resource, error) {
account, err := e.repository.GetAccount()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, 1)
if account != nil {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
"api-gateway-account",
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayApiKeyEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayApiKeyEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayApiKeyEnumerator {
return &ApiGatewayApiKeyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayApiKeyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayApiKeyResourceType
}
func (e *ApiGatewayApiKeyEnumerator) Enumerate() ([]*resource.Resource, error) {
keys, err := e.repository.ListAllApiKeys()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(keys))
for _, key := range keys {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*key.Id,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,56 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayAuthorizerEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayAuthorizerEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayAuthorizerEnumerator {
return &ApiGatewayAuthorizerEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayAuthorizerEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayAuthorizerResourceType
}
func (e *ApiGatewayAuthorizerEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
authorizers, err := e.repository.ListAllRestApiAuthorizers(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, authorizer := range authorizers {
au := authorizer
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*au.Id,
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,64 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayBasePathMappingEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayBasePathMappingEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayBasePathMappingEnumerator {
return &ApiGatewayBasePathMappingEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayBasePathMappingEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayBasePathMappingResourceType
}
func (e *ApiGatewayBasePathMappingEnumerator) Enumerate() ([]*resource.Resource, error) {
domainNames, err := e.repository.ListAllDomainNames()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayDomainNameResourceType)
}
results := make([]*resource.Resource, 0)
for _, domainName := range domainNames {
d := domainName
mappings, err := e.repository.ListAllDomainNameBasePathMappings(*d.DomainName)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, mapping := range mappings {
m := mapping
basePath := ""
if m.BasePath != nil && *m.BasePath != "(none)" {
basePath = *m.BasePath
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{*d.DomainName, basePath}, "/"),
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayDomainNameEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayDomainNameEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayDomainNameEnumerator {
return &ApiGatewayDomainNameEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayDomainNameEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayDomainNameResourceType
}
func (e *ApiGatewayDomainNameEnumerator) Enumerate() ([]*resource.Resource, error) {
domainNames, err := e.repository.ListAllDomainNames()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(domainNames))
for _, domainName := range domainNames {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*domainName.DomainName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,57 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayGatewayResponseEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayGatewayResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayGatewayResponseEnumerator {
return &ApiGatewayGatewayResponseEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayGatewayResponseEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayGatewayResponseResourceType
}
func (e *ApiGatewayGatewayResponseEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
gtwResponses, err := e.repository.ListAllRestApiGatewayResponses(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, gtwResponse := range gtwResponses {
g := gtwResponse
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{"aggr", *a.Id, *g.ResponseType}, "-"),
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,59 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayIntegrationEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayIntegrationEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayIntegrationEnumerator {
return &ApiGatewayIntegrationEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayIntegrationEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayIntegrationResourceType
}
func (e *ApiGatewayIntegrationEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
resources, err := e.repository.ListAllRestApiResources(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType)
}
for _, resource := range resources {
r := resource
for httpMethod := range r.ResourceMethods {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{"agi", *a.Id, *r.Id, httpMethod}, "-"),
map[string]interface{}{},
),
)
}
}
}
return results, err
}

View File

@ -0,0 +1,63 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayIntegrationResponseEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayIntegrationResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayIntegrationResponseEnumerator {
return &ApiGatewayIntegrationResponseEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayIntegrationResponseEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayIntegrationResponseResourceType
}
func (e *ApiGatewayIntegrationResponseEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
resources, err := e.repository.ListAllRestApiResources(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType)
}
for _, resource := range resources {
r := resource
for httpMethod, method := range r.ResourceMethods {
if method.MethodIntegration != nil {
for statusCode := range method.MethodIntegration.IntegrationResponses {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{"agir", *a.Id, *r.Id, httpMethod, statusCode}, "-"),
map[string]interface{}{},
),
)
}
}
}
}
}
return results, err
}

View File

@ -0,0 +1,59 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayMethodEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayMethodEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodEnumerator {
return &ApiGatewayMethodEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayMethodEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayMethodResourceType
}
func (e *ApiGatewayMethodEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
resources, err := e.repository.ListAllRestApiResources(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType)
}
for _, resource := range resources {
r := resource
for method := range r.ResourceMethods {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{"agm", *a.Id, *r.Id, method}, "-"),
map[string]interface{}{},
),
)
}
}
}
return results, err
}

View File

@ -0,0 +1,61 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayMethodResponseEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayMethodResponseEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodResponseEnumerator {
return &ApiGatewayMethodResponseEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayMethodResponseEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayMethodResponseResourceType
}
func (e *ApiGatewayMethodResponseEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
resources, err := e.repository.ListAllRestApiResources(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayResourceResourceType)
}
for _, resource := range resources {
r := resource
for httpMethod, method := range r.ResourceMethods {
for statusCode := range method.MethodResponses {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{"agmr", *a.Id, *r.Id, httpMethod, statusCode}, "-"),
map[string]interface{}{},
),
)
}
}
}
}
return results, err
}

View File

@ -0,0 +1,59 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayMethodSettingsEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayMethodSettingsEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayMethodSettingsEnumerator {
return &ApiGatewayMethodSettingsEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayMethodSettingsEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayMethodSettingsResourceType
}
func (e *ApiGatewayMethodSettingsEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
stages, err := e.repository.ListAllRestApiStages(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayStageResourceType)
}
for _, stage := range stages {
s := stage
for methodPath := range s.MethodSettings {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{*a.Id, *s.StageName, methodPath}, "-"),
map[string]interface{}{},
),
)
}
}
}
return results, err
}

View File

@ -0,0 +1,55 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayModelEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayModelEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayModelEnumerator {
return &ApiGatewayModelEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayModelEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayModelResourceType
}
func (e *ApiGatewayModelEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
models, err := e.repository.ListAllRestApiModels(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, model := range models {
m := model
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*m.Id,
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,55 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayRequestValidatorEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayRequestValidatorEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRequestValidatorEnumerator {
return &ApiGatewayRequestValidatorEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayRequestValidatorEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayRequestValidatorResourceType
}
func (e *ApiGatewayRequestValidatorEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
requestValidators, err := e.repository.ListAllRestApiRequestValidators(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, requestValidator := range requestValidators {
r := requestValidator
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*r.Id,
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,58 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayResourceEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayResourceEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayResourceEnumerator {
return &ApiGatewayResourceEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayResourceEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayResourceResourceType
}
func (e *ApiGatewayResourceEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
resources, err := e.repository.ListAllRestApiResources(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, resource := range resources {
r := resource
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*r.Id,
map[string]interface{}{
"rest_api_id": *a.Id,
"path": *r.Path,
},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayRestApiEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayRestApiEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRestApiEnumerator {
return &ApiGatewayRestApiEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayRestApiEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayRestApiResourceType
}
func (e *ApiGatewayRestApiEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(apis))
for _, api := range apis {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*api.Id,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,49 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayRestApiPolicyEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayRestApiPolicyEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayRestApiPolicyEnumerator {
return &ApiGatewayRestApiPolicyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayRestApiPolicyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayRestApiPolicyResourceType
}
func (e *ApiGatewayRestApiPolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
if a.Policy == nil || *a.Policy == "" {
continue
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*a.Id,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,57 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayStageEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayStageEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayStageEnumerator {
return &ApiGatewayStageEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayStageEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayStageResourceType
}
func (e *ApiGatewayStageEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllRestApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
stages, err := e.repository.ListAllRestApiStages(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, stage := range stages {
s := stage
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{"ags", *a.Id, *s.StageName}, "-"),
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayVpcLinkEnumerator struct {
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayVpcLinkEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayVpcLinkEnumerator {
return &ApiGatewayVpcLinkEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayVpcLinkEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayVpcLinkResourceType
}
func (e *ApiGatewayVpcLinkEnumerator) Enumerate() ([]*resource.Resource, error) {
vpcLinks, err := e.repository.ListAllVpcLinks()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(vpcLinks))
for _, vpcLink := range vpcLinks {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*vpcLink.Id,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2ApiEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2ApiEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2ApiEnumerator {
return &ApiGatewayV2ApiEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2ApiEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2ApiResourceType
}
func (e *ApiGatewayV2ApiEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(apis))
for _, api := range apis {
a := api
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*a.ApiId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,56 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2AuthorizerEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2AuthorizerEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2AuthorizerEnumerator {
return &ApiGatewayV2AuthorizerEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2AuthorizerEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2AuthorizerResourceType
}
func (e *ApiGatewayV2AuthorizerEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
a := api
authorizers, err := e.repository.ListAllApiAuthorizers(*a.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, authorizer := range authorizers {
au := authorizer
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*au.AuthorizerId,
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2DeploymentEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2DeploymentEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2DeploymentEnumerator {
return &ApiGatewayV2DeploymentEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2DeploymentEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2DeploymentResourceType
}
func (e *ApiGatewayV2DeploymentEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
var results []*resource.Resource
for _, api := range apis {
deployments, err := e.repository.ListAllApiDeployments(api.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, deployment := range deployments {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*deployment.DeploymentId,
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,49 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2DomainNameEnumerator struct {
// AWS SDK list domain names endpoint from API Gateway v2 returns the
// same results as the v1 one, thus let's re-use the method from
// the API Gateway v1
repository repository.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayV2DomainNameEnumerator(repo repository.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayV2DomainNameEnumerator {
return &ApiGatewayV2DomainNameEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2DomainNameEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2DomainNameResourceType
}
func (e *ApiGatewayV2DomainNameEnumerator) Enumerate() ([]*resource.Resource, error) {
domainNames, err := e.repository.ListAllDomainNames()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(domainNames))
for _, domainName := range domainNames {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*domainName.DomainName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,63 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2IntegrationEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2IntegrationEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2IntegrationEnumerator {
return &ApiGatewayV2IntegrationEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2IntegrationEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2IntegrationResourceType
}
func (e *ApiGatewayV2IntegrationEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, a := range apis {
api := a
integrations, err := e.repository.ListAllApiIntegrations(*api.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, integration := range integrations {
data := map[string]interface{}{
"api_id": *api.ApiId,
"integration_type": *integration.IntegrationType,
}
if integration.IntegrationMethod != nil {
// this is needed to discriminate in middleware. But it is nil when the type is mock...
data["integration_method"] = *integration.IntegrationMethod
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*integration.IntegrationId,
data,
),
)
}
}
return results, err
}

View File

@ -0,0 +1,63 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2IntegrationResponseEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2IntegrationResponseEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2IntegrationResponseEnumerator {
return &ApiGatewayV2IntegrationResponseEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2IntegrationResponseEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2IntegrationResponseResourceType
}
func (e *ApiGatewayV2IntegrationResponseEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, a := range apis {
apiID := *a.ApiId
integrations, err := e.repository.ListAllApiIntegrations(apiID)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2IntegrationResourceType)
}
for _, integration := range integrations {
integrationId := *integration.IntegrationId
responses, err := e.repository.ListAllApiIntegrationResponses(apiID, integrationId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, resp := range responses {
responseId := *resp.IntegrationResponseId
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
responseId,
map[string]interface{}{},
),
)
}
}
}
return results, err
}

View File

@ -0,0 +1,61 @@
package aws
import (
repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2MappingEnumerator struct {
repository repository2.ApiGatewayV2Repository
repositoryV1 repository2.ApiGatewayRepository
factory resource.ResourceFactory
}
func NewApiGatewayV2MappingEnumerator(repo repository2.ApiGatewayV2Repository, repov1 repository2.ApiGatewayRepository, factory resource.ResourceFactory) *ApiGatewayV2MappingEnumerator {
return &ApiGatewayV2MappingEnumerator{
repository: repo,
repositoryV1: repov1,
factory: factory,
}
}
func (e *ApiGatewayV2MappingEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2MappingResourceType
}
func (e *ApiGatewayV2MappingEnumerator) Enumerate() ([]*resource.Resource, error) {
domainNames, err := e.repositoryV1.ListAllDomainNames()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayDomainNameResourceType)
}
var results []*resource.Resource
for _, domainName := range domainNames {
mappings, err := e.repository.ListAllApiMappings(*domainName.DomainName)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, mapping := range mappings {
attrs := make(map[string]interface{})
if mapping.ApiId != nil {
attrs["api_id"] = *mapping.ApiId
}
if mapping.Stage != nil {
attrs["stage"] = *mapping.Stage
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*mapping.ApiMappingId,
attrs,
),
)
}
}
return results, err
}

View File

@ -0,0 +1,52 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2ModelEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2ModelEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2ModelEnumerator {
return &ApiGatewayV2ModelEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2ModelEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2ModelResourceType
}
func (e *ApiGatewayV2ModelEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
var results []*resource.Resource
for _, api := range apis {
models, err := e.repository.ListAllApiModels(*api.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, model := range models {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*model.ModelId,
map[string]interface{}{
"name": *model.Name,
},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,53 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2RouteEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2RouteEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2RouteEnumerator {
return &ApiGatewayV2RouteEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2RouteEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2RouteResourceType
}
func (e *ApiGatewayV2RouteEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
var results []*resource.Resource
for _, api := range apis {
routes, err := e.repository.ListAllApiRoutes(api.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, route := range routes {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*route.RouteId,
map[string]interface{}{
"api_id": *api.ApiId,
"route_key": *route.RouteKey,
},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,59 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2RouteResponseEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2RouteResponseEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2RouteResponseEnumerator {
return &ApiGatewayV2RouteResponseEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2RouteResponseEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2RouteResponseResourceType
}
func (e *ApiGatewayV2RouteResponseEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
var results []*resource.Resource
for _, api := range apis {
a := api
routes, err := e.repository.ListAllApiRoutes(a.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2RouteResourceType)
}
for _, route := range routes {
r := route
responses, err := e.repository.ListAllApiRouteResponses(*a.ApiId, *r.RouteId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, response := range responses {
res := response
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*res.RouteResponseId,
map[string]interface{}{},
),
)
}
}
}
return results, err
}

View File

@ -0,0 +1,54 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2StageEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2StageEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2StageEnumerator {
return &ApiGatewayV2StageEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2StageEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2StageResourceType
}
func (e *ApiGatewayV2StageEnumerator) Enumerate() ([]*resource.Resource, error) {
apis, err := e.repository.ListAllApis()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayV2ApiResourceType)
}
results := make([]*resource.Resource, 0)
for _, api := range apis {
stages, err := e.repository.ListAllApiStages(*api.ApiId)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, stage := range stages {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*stage.StageName,
map[string]interface{}{},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ApiGatewayV2VpcLinkEnumerator struct {
repository repository.ApiGatewayV2Repository
factory resource.ResourceFactory
}
func NewApiGatewayV2VpcLinkEnumerator(repo repository.ApiGatewayV2Repository, factory resource.ResourceFactory) *ApiGatewayV2VpcLinkEnumerator {
return &ApiGatewayV2VpcLinkEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ApiGatewayV2VpcLinkEnumerator) SupportedType() resource.ResourceType {
return aws.AwsApiGatewayV2VpcLinkResourceType
}
func (e *ApiGatewayV2VpcLinkEnumerator) Enumerate() ([]*resource.Resource, error) {
vpcLinks, err := e.repository.ListAllVpcLinks()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(vpcLinks))
for _, vpcLink := range vpcLinks {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*vpcLink.VpcLinkId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,53 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type AppAutoscalingPolicyEnumerator struct {
repository repository.AppAutoScalingRepository
factory resource.ResourceFactory
}
func NewAppAutoscalingPolicyEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingPolicyEnumerator {
return &AppAutoscalingPolicyEnumerator{
repository,
factory,
}
}
func (e *AppAutoscalingPolicyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsAppAutoscalingPolicyResourceType
}
func (e *AppAutoscalingPolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
results := make([]*resource.Resource, 0)
for _, ns := range e.repository.ServiceNamespaceValues() {
policies, err := e.repository.DescribeScalingPolicies(ns)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, policy := range policies {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*policy.PolicyName,
map[string]interface{}{
"name": *policy.PolicyName,
"resource_id": *policy.ResourceId,
"scalable_dimension": *policy.ScalableDimension,
"service_namespace": *policy.ServiceNamespace,
},
),
)
}
}
return results, nil
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"strings"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type AppAutoscalingScheduledActionEnumerator struct {
repository repository.AppAutoScalingRepository
factory resource.ResourceFactory
}
func NewAppAutoscalingScheduledActionEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingScheduledActionEnumerator {
return &AppAutoscalingScheduledActionEnumerator{
repository,
factory,
}
}
func (e *AppAutoscalingScheduledActionEnumerator) SupportedType() resource.ResourceType {
return aws.AwsAppAutoscalingScheduledActionResourceType
}
func (e *AppAutoscalingScheduledActionEnumerator) Enumerate() ([]*resource.Resource, error) {
results := make([]*resource.Resource, 0)
for _, ns := range e.repository.ServiceNamespaceValues() {
actions, err := e.repository.DescribeScheduledActions(ns)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, action := range actions {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
strings.Join([]string{*action.ScheduledActionName, *action.ServiceNamespace, *action.ResourceId}, "-"),
map[string]interface{}{},
),
)
}
}
return results, nil
}

View File

@ -0,0 +1,55 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/applicationautoscaling"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type AppAutoscalingTargetEnumerator struct {
repository repository.AppAutoScalingRepository
factory resource.ResourceFactory
}
func NewAppAutoscalingTargetEnumerator(repository repository.AppAutoScalingRepository, factory resource.ResourceFactory) *AppAutoscalingTargetEnumerator {
return &AppAutoscalingTargetEnumerator{
repository,
factory,
}
}
func (e *AppAutoscalingTargetEnumerator) SupportedType() resource.ResourceType {
return aws.AwsAppAutoscalingTargetResourceType
}
func (e *AppAutoscalingTargetEnumerator) Enumerate() ([]*resource.Resource, error) {
targets := make([]*applicationautoscaling.ScalableTarget, 0)
for _, ns := range e.repository.ServiceNamespaceValues() {
results, err := e.repository.DescribeScalableTargets(ns)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
targets = append(targets, results...)
}
results := make([]*resource.Resource, 0, len(targets))
for _, target := range targets {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*target.ResourceId,
map[string]interface{}{
"service_namespace": *target.ServiceNamespace,
"scalable_dimension": *target.ScalableDimension,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ClassicLoadBalancerEnumerator struct {
repository repository.ELBRepository
factory resource.ResourceFactory
}
func NewClassicLoadBalancerEnumerator(repo repository.ELBRepository, factory resource.ResourceFactory) *ClassicLoadBalancerEnumerator {
return &ClassicLoadBalancerEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ClassicLoadBalancerEnumerator) SupportedType() resource.ResourceType {
return aws.AwsClassicLoadBalancerResourceType
}
func (e *ClassicLoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) {
loadBalancers, err := e.repository.ListAllLoadBalancers()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(loadBalancers))
for _, lb := range loadBalancers {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*lb.LoadBalancerName,
map[string]interface{}{},
),
)
}
return results, nil
}

View File

@ -0,0 +1,60 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/cloudformation"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type CloudformationStackEnumerator struct {
repository repository.CloudformationRepository
factory resource.ResourceFactory
}
func NewCloudformationStackEnumerator(repo repository.CloudformationRepository, factory resource.ResourceFactory) *CloudformationStackEnumerator {
return &CloudformationStackEnumerator{
repository: repo,
factory: factory,
}
}
func (e *CloudformationStackEnumerator) SupportedType() resource.ResourceType {
return aws.AwsCloudformationStackResourceType
}
func (e *CloudformationStackEnumerator) Enumerate() ([]*resource.Resource, error) {
stacks, err := e.repository.ListAllStacks()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(stacks))
for _, stack := range stacks {
attrs := map[string]interface{}{}
if stack.Parameters != nil && len(stack.Parameters) > 0 {
attrs["parameters"] = flattenParameters(stack.Parameters)
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*stack.StackId,
attrs,
),
)
}
return results, err
}
func flattenParameters(parameters []*cloudformation.Parameter) interface{} {
params := make(map[string]interface{}, len(parameters))
for _, p := range parameters {
params[*p.ParameterKey] = *p.ParameterValue
}
return params
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type CloudfrontDistributionEnumerator struct {
repository repository.CloudfrontRepository
factory resource.ResourceFactory
}
func NewCloudfrontDistributionEnumerator(repo repository.CloudfrontRepository, factory resource.ResourceFactory) *CloudfrontDistributionEnumerator {
return &CloudfrontDistributionEnumerator{
repository: repo,
factory: factory,
}
}
func (e *CloudfrontDistributionEnumerator) SupportedType() resource.ResourceType {
return aws.AwsCloudfrontDistributionResourceType
}
func (e *CloudfrontDistributionEnumerator) Enumerate() ([]*resource.Resource, error) {
distributions, err := e.repository.ListAllDistributions()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(distributions))
for _, distribution := range distributions {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*distribution.Id,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,47 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource/aws"
"github.com/snyk/driftctl/enumeration/resource"
)
type DefaultVPCEnumerator struct {
repo repository.EC2Repository
factory resource.ResourceFactory
}
func NewDefaultVPCEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *DefaultVPCEnumerator {
return &DefaultVPCEnumerator{
repo,
factory,
}
}
func (e *DefaultVPCEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDefaultVpcResourceType
}
func (e *DefaultVPCEnumerator) Enumerate() ([]*resource.Resource, error) {
_, defaultVPCs, err := e.repo.ListAllVPCs()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(defaultVPCs))
for _, item := range defaultVPCs {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*item.VpcId,
map[string]interface{}{},
),
)
}
return results, nil
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type DynamoDBTableEnumerator struct {
repository repository.DynamoDBRepository
factory resource.ResourceFactory
}
func NewDynamoDBTableEnumerator(repository repository.DynamoDBRepository, factory resource.ResourceFactory) *DynamoDBTableEnumerator {
return &DynamoDBTableEnumerator{
repository,
factory,
}
}
func (e *DynamoDBTableEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDynamodbTableResourceType
}
func (e *DynamoDBTableEnumerator) Enumerate() ([]*resource.Resource, error) {
tables, err := e.repository.ListAllTables()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(tables))
for _, table := range tables {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*table,
map[string]interface{}{},
),
)
}
return results, nil
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2EbsEncryptionByDefaultEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2EbsEncryptionByDefaultEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsEncryptionByDefaultEnumerator {
return &EC2EbsEncryptionByDefaultEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2EbsEncryptionByDefaultEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEbsEncryptionByDefaultResourceType
}
func (e *EC2EbsEncryptionByDefaultEnumerator) Enumerate() ([]*resource.Resource, error) {
enabled, err := e.repository.IsEbsEncryptionEnabledByDefault()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0)
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
"ebs_encryption_default",
map[string]interface{}{
"enabled": enabled,
},
),
)
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2AmiEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2AmiEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2AmiEnumerator {
return &EC2AmiEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2AmiEnumerator) SupportedType() resource.ResourceType {
return aws.AwsAmiResourceType
}
func (e *EC2AmiEnumerator) Enumerate() ([]*resource.Resource, error) {
images, err := e.repository.ListAllImages()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(images))
for _, image := range images {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*image.ImageId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2DefaultNetworkACLEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2DefaultNetworkACLEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultNetworkACLEnumerator {
return &EC2DefaultNetworkACLEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2DefaultNetworkACLEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDefaultNetworkACLResourceType
}
func (e *EC2DefaultNetworkACLEnumerator) Enumerate() ([]*resource.Resource, error) {
resources, err := e.repository.ListAllNetworkACLs()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(resources))
for _, res := range resources {
// Do not handle non-default network acl since it is a dedicated resource
if !*res.IsDefault {
continue
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*res.NetworkAclId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2DefaultRouteTableEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2DefaultRouteTableEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultRouteTableEnumerator {
return &EC2DefaultRouteTableEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2DefaultRouteTableEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDefaultRouteTableResourceType
}
func (e *EC2DefaultRouteTableEnumerator) Enumerate() ([]*resource.Resource, error) {
routeTables, err := e.repository.ListAllRouteTables()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
var results []*resource.Resource
for _, routeTable := range routeTables {
if isMainRouteTable(routeTable) {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*routeTable.RouteTableId,
map[string]interface{}{
"vpc_id": *routeTable.VpcId,
},
),
)
}
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2DefaultSubnetEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2DefaultSubnetEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2DefaultSubnetEnumerator {
return &EC2DefaultSubnetEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2DefaultSubnetEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDefaultSubnetResourceType
}
func (e *EC2DefaultSubnetEnumerator) Enumerate() ([]*resource.Resource, error) {
_, defaultSubnets, err := e.repository.ListAllSubnets()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(defaultSubnets))
for _, subnet := range defaultSubnets {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*subnet.SubnetId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2EbsSnapshotEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2EbsSnapshotEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsSnapshotEnumerator {
return &EC2EbsSnapshotEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2EbsSnapshotEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEbsSnapshotResourceType
}
func (e *EC2EbsSnapshotEnumerator) Enumerate() ([]*resource.Resource, error) {
snapshots, err := e.repository.ListAllSnapshots()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(snapshots))
for _, snapshot := range snapshots {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*snapshot.SnapshotId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2EbsVolumeEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2EbsVolumeEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EbsVolumeEnumerator {
return &EC2EbsVolumeEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2EbsVolumeEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEbsVolumeResourceType
}
func (e *EC2EbsVolumeEnumerator) Enumerate() ([]*resource.Resource, error) {
volumes, err := e.repository.ListAllVolumes()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(volumes))
for _, volume := range volumes {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*volume.VolumeId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,48 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2EipAssociationEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2EipAssociationEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EipAssociationEnumerator {
return &EC2EipAssociationEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2EipAssociationEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEipAssociationResourceType
}
func (e *EC2EipAssociationEnumerator) Enumerate() ([]*resource.Resource, error) {
addresses, err := e.repository.ListAllAddressesAssociation()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(addresses))
for _, address := range addresses {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*address.AssociationId,
map[string]interface{}{
"allocation_id": *address.AllocationId,
},
),
)
}
return results, err
}

View File

@ -0,0 +1,51 @@
package aws
import (
"github.com/sirupsen/logrus"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2EipEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2EipEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2EipEnumerator {
return &EC2EipEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2EipEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEipResourceType
}
func (e *EC2EipEnumerator) Enumerate() ([]*resource.Resource, error) {
addresses, err := e.repository.ListAllAddresses()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(addresses))
for _, address := range addresses {
if address.AllocationId == nil {
logrus.Warn("Elastic IP does not have an allocation ID, ignoring")
continue
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*address.AllocationId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2InstanceEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2InstanceEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2InstanceEnumerator {
return &EC2InstanceEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2InstanceEnumerator) SupportedType() resource.ResourceType {
return aws.AwsInstanceResourceType
}
func (e *EC2InstanceEnumerator) Enumerate() ([]*resource.Resource, error) {
instances, err := e.repository.ListAllInstances()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(instances))
for _, instance := range instances {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*instance.InstanceId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2InternetGatewayEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2InternetGatewayEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2InternetGatewayEnumerator {
return &EC2InternetGatewayEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2InternetGatewayEnumerator) SupportedType() resource.ResourceType {
return aws.AwsInternetGatewayResourceType
}
func (e *EC2InternetGatewayEnumerator) Enumerate() ([]*resource.Resource, error) {
internetGateways, err := e.repository.ListAllInternetGateways()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(internetGateways))
for _, internetGateway := range internetGateways {
data := map[string]interface{}{}
if len(internetGateway.Attachments) > 0 && internetGateway.Attachments[0].VpcId != nil {
data["vpc_id"] = *internetGateway.Attachments[0].VpcId
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*internetGateway.InternetGatewayId,
data,
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2KeyPairEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2KeyPairEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2KeyPairEnumerator {
return &EC2KeyPairEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2KeyPairEnumerator) SupportedType() resource.ResourceType {
return aws.AwsKeyPairResourceType
}
func (e *EC2KeyPairEnumerator) Enumerate() ([]*resource.Resource, error) {
keyPairs, err := e.repository.ListAllKeyPairs()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(keyPairs))
for _, keyPair := range keyPairs {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*keyPair.KeyName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,54 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2NatGatewayEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2NatGatewayEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NatGatewayEnumerator {
return &EC2NatGatewayEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2NatGatewayEnumerator) SupportedType() resource.ResourceType {
return aws.AwsNatGatewayResourceType
}
func (e *EC2NatGatewayEnumerator) Enumerate() ([]*resource.Resource, error) {
natGateways, err := e.repository.ListAllNatGateways()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(natGateways))
for _, natGateway := range natGateways {
attrs := map[string]interface{}{}
if len(natGateway.NatGatewayAddresses) > 0 {
if allocId := natGateway.NatGatewayAddresses[0].AllocationId; allocId != nil {
attrs["allocation_id"] = *allocId
}
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*natGateway.NatGatewayId,
attrs,
),
)
}
return results, err
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2NetworkACLEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2NetworkACLEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NetworkACLEnumerator {
return &EC2NetworkACLEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2NetworkACLEnumerator) SupportedType() resource.ResourceType {
return aws.AwsNetworkACLResourceType
}
func (e *EC2NetworkACLEnumerator) Enumerate() ([]*resource.Resource, error) {
resources, err := e.repository.ListAllNetworkACLs()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(resources))
for _, res := range resources {
// Do not handle default network acl since it is a dedicated resource
if *res.IsDefault {
continue
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*res.NetworkAclId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,70 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2NetworkACLRuleEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2NetworkACLRuleEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2NetworkACLRuleEnumerator {
return &EC2NetworkACLRuleEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2NetworkACLRuleEnumerator) SupportedType() resource.ResourceType {
return aws.AwsNetworkACLRuleResourceType
}
func (e *EC2NetworkACLRuleEnumerator) Enumerate() ([]*resource.Resource, error) {
resources, err := e.repository.ListAllNetworkACLs()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsNetworkACLResourceType)
}
results := make([]*resource.Resource, 0, len(resources))
for _, res := range resources {
for _, entry := range res.Entries {
attrs := map[string]interface{}{
"egress": *entry.Egress,
"network_acl_id": *res.NetworkAclId,
"rule_action": *entry.RuleAction, // Used in default middleware
"rule_number": float64(*entry.RuleNumber), // Used in default middleware
"protocol": *entry.Protocol, // Used in default middleware
}
if entry.CidrBlock != nil {
attrs["cidr_block"] = *entry.CidrBlock
}
if entry.Ipv6CidrBlock != nil {
attrs["ipv6_cidr_block"] = *entry.Ipv6CidrBlock
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
aws.CreateNetworkACLRuleID(
*res.NetworkAclId,
int(float64(*entry.RuleNumber)),
*entry.Egress,
*entry.Protocol,
),
attrs,
),
)
}
}
return results, err
}

View File

@ -0,0 +1,66 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2RouteEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2RouteEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteEnumerator {
return &EC2RouteEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2RouteEnumerator) SupportedType() resource.ResourceType {
return aws.AwsRouteResourceType
}
func (e *EC2RouteEnumerator) Enumerate() ([]*resource.Resource, error) {
routeTables, err := e.repository.ListAllRouteTables()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsRouteTableResourceType)
}
var results []*resource.Resource
for _, routeTable := range routeTables {
for _, route := range routeTable.Routes {
routeId := aws.CalculateRouteID(routeTable.RouteTableId, route.DestinationCidrBlock, route.DestinationIpv6CidrBlock, route.DestinationPrefixListId)
data := map[string]interface{}{
"route_table_id": *routeTable.RouteTableId,
"origin": *route.Origin,
}
if route.DestinationCidrBlock != nil && *route.DestinationCidrBlock != "" {
data["destination_cidr_block"] = *route.DestinationCidrBlock
}
if route.DestinationIpv6CidrBlock != nil && *route.DestinationIpv6CidrBlock != "" {
data["destination_ipv6_cidr_block"] = *route.DestinationIpv6CidrBlock
}
if route.DestinationPrefixListId != nil && *route.DestinationPrefixListId != "" {
data["destination_prefix_list_id"] = *route.DestinationPrefixListId
}
if route.GatewayId != nil && *route.GatewayId != "" {
data["gateway_id"] = *route.GatewayId
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
routeId,
data,
),
)
}
}
return results, err
}

View File

@ -0,0 +1,69 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2RouteTableAssociationEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2RouteTableAssociationEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteTableAssociationEnumerator {
return &EC2RouteTableAssociationEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2RouteTableAssociationEnumerator) SupportedType() resource.ResourceType {
return aws.AwsRouteTableAssociationResourceType
}
func (e *EC2RouteTableAssociationEnumerator) Enumerate() ([]*resource.Resource, error) {
routeTables, err := e.repository.ListAllRouteTables()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsRouteTableResourceType)
}
var results []*resource.Resource
for _, routeTable := range routeTables {
for _, assoc := range routeTable.Associations {
if e.shouldBeIgnored(assoc) {
continue
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*assoc.RouteTableAssociationId,
map[string]interface{}{
"route_table_id": *assoc.RouteTableId,
},
),
)
}
}
return results, err
}
func (e *EC2RouteTableAssociationEnumerator) shouldBeIgnored(assoc *ec2.RouteTableAssociation) bool {
// Ignore when nothing is associated
if assoc.GatewayId == nil && assoc.SubnetId == nil {
return true
}
// Ignore when association is not associated
if assoc.AssociationState != nil && assoc.AssociationState.State != nil &&
*assoc.AssociationState.State != "associated" {
return true
}
return false
}

View File

@ -0,0 +1,58 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2RouteTableEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2RouteTableEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2RouteTableEnumerator {
return &EC2RouteTableEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2RouteTableEnumerator) SupportedType() resource.ResourceType {
return aws.AwsRouteTableResourceType
}
func (e *EC2RouteTableEnumerator) Enumerate() ([]*resource.Resource, error) {
routeTables, err := e.repository.ListAllRouteTables()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
var results []*resource.Resource
for _, routeTable := range routeTables {
if !isMainRouteTable(routeTable) {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*routeTable.RouteTableId,
map[string]interface{}{},
),
)
}
}
return results, err
}
func isMainRouteTable(routeTable *ec2.RouteTable) bool {
for _, assoc := range routeTable.Associations {
if assoc.Main != nil && *assoc.Main {
return true
}
}
return false
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type EC2SubnetEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewEC2SubnetEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *EC2SubnetEnumerator {
return &EC2SubnetEnumerator{
repository: repo,
factory: factory,
}
}
func (e *EC2SubnetEnumerator) SupportedType() resource.ResourceType {
return aws.AwsSubnetResourceType
}
func (e *EC2SubnetEnumerator) Enumerate() ([]*resource.Resource, error) {
subnets, _, err := e.repository.ListAllSubnets()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(subnets))
for _, subnet := range subnets {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*subnet.SubnetId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ECRRepositoryEnumerator struct {
repository repository.ECRRepository
factory resource.ResourceFactory
}
func NewECRRepositoryEnumerator(repo repository.ECRRepository, factory resource.ResourceFactory) *ECRRepositoryEnumerator {
return &ECRRepositoryEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ECRRepositoryEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEcrRepositoryResourceType
}
func (e *ECRRepositoryEnumerator) Enumerate() ([]*resource.Resource, error) {
repos, err := e.repository.ListAllRepositories()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(repos))
for _, repo := range repos {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*repo.RepositoryName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,55 @@
package aws
import (
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ECRRepositoryPolicyEnumerator struct {
repository repository.ECRRepository
factory resource.ResourceFactory
}
func NewECRRepositoryPolicyEnumerator(repo repository.ECRRepository, factory resource.ResourceFactory) *ECRRepositoryPolicyEnumerator {
return &ECRRepositoryPolicyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ECRRepositoryPolicyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsEcrRepositoryPolicyResourceType
}
func (e *ECRRepositoryPolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
repos, err := e.repository.ListAllRepositories()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsEcrRepositoryResourceType)
}
results := make([]*resource.Resource, 0, len(repos))
for _, repo := range repos {
repoOutput, err := e.repository.GetRepositoryPolicy(repo)
if _, ok := err.(*ecr.RepositoryPolicyNotFoundException); ok {
continue
}
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*repoOutput.RepositoryName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type ElastiCacheClusterEnumerator struct {
repository repository.ElastiCacheRepository
factory resource.ResourceFactory
}
func NewElastiCacheClusterEnumerator(repo repository.ElastiCacheRepository, factory resource.ResourceFactory) *ElastiCacheClusterEnumerator {
return &ElastiCacheClusterEnumerator{
repository: repo,
factory: factory,
}
}
func (e *ElastiCacheClusterEnumerator) SupportedType() resource.ResourceType {
return aws.AwsElastiCacheClusterResourceType
}
func (e *ElastiCacheClusterEnumerator) Enumerate() ([]*resource.Resource, error) {
clusters, err := e.repository.ListAllCacheClusters()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(clusters))
for _, cluster := range clusters {
c := cluster
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*c.CacheClusterId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,52 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamAccessKeyEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamAccessKeyEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamAccessKeyEnumerator {
return &IamAccessKeyEnumerator{
repository,
factory,
}
}
func (e *IamAccessKeyEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsIamAccessKeyResourceType
}
func (e *IamAccessKeyEnumerator) Enumerate() ([]*resource.Resource, error) {
users, err := e.repository.ListAllUsers()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamUserResourceType)
}
keys, err := e.repository.ListAllAccessKeys(users)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0)
for _, key := range keys {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*key.AccessKeyId,
map[string]interface{}{
"user": *key.UserName,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamGroupEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamGroupEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamGroupEnumerator {
return &IamGroupEnumerator{
repository: repo,
factory: factory,
}
}
func (e *IamGroupEnumerator) SupportedType() resource.ResourceType {
return aws.AwsIamGroupResourceType
}
func (e *IamGroupEnumerator) Enumerate() ([]*resource.Resource, error) {
groups, err := e.repository.ListAllGroups()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamGroupResourceType)
}
results := make([]*resource.Resource, 0, len(groups))
for _, group := range groups {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*group.GroupName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,56 @@
package aws
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamGroupPolicyAttachmentEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamGroupPolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamGroupPolicyAttachmentEnumerator {
return &IamGroupPolicyAttachmentEnumerator{
repository,
factory,
}
}
func (e *IamGroupPolicyAttachmentEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsIamGroupPolicyAttachmentResourceType
}
func (e *IamGroupPolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) {
groups, err := e.repository.ListAllGroups()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamGroupResourceType)
}
results := make([]*resource.Resource, 0)
policyAttachments, err := e.repository.ListAllGroupPolicyAttachments(groups)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, attachedPol := range policyAttachments {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.GroupName),
map[string]interface{}{
"group": attachedPol.GroupName,
"policy_arn": *attachedPol.PolicyArn,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamGroupPolicyEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamGroupPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamGroupPolicyEnumerator {
return &IamGroupPolicyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *IamGroupPolicyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsIamGroupPolicyResourceType
}
func (e *IamGroupPolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
groups, err := e.repository.ListAllGroups()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamGroupResourceType)
}
groupPolicies, err := e.repository.ListAllGroupPolicies(groups)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(groupPolicies))
for _, groupPolicy := range groupPolicies {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
groupPolicy,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,47 @@
package aws
import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamPolicyEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamPolicyEnumerator {
return &IamPolicyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *IamPolicyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsIamPolicyResourceType
}
func (e *IamPolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
policies, err := e.repository.ListAllPolicies()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(policies))
for _, policy := range policies {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
awssdk.StringValue(policy.Arn),
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,65 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
var iamRoleExclusionList = map[string]struct{}{
// Enabled by default for aws to enable support, not removable
"AWSServiceRoleForSupport": {},
// Enabled and not removable for every org account
"AWSServiceRoleForOrganizations": {},
// Not manageable by IaC and set by default
"AWSServiceRoleForTrustedAdvisor": {},
}
type IamRoleEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamRoleEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRoleEnumerator {
return &IamRoleEnumerator{
repository,
factory,
}
}
func (e *IamRoleEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsIamRoleResourceType
}
func awsIamRoleShouldBeIgnored(roleName string) bool {
_, ok := iamRoleExclusionList[roleName]
return ok
}
func (e *IamRoleEnumerator) Enumerate() ([]*resource.Resource, error) {
roles, err := e.repository.ListAllRoles()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0)
for _, role := range roles {
if role.RoleName != nil && awsIamRoleShouldBeIgnored(*role.RoleName) {
continue
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*role.RoleName,
map[string]interface{}{
"path": *role.Path,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,69 @@
package aws
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamRolePolicyAttachmentEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamRolePolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRolePolicyAttachmentEnumerator {
return &IamRolePolicyAttachmentEnumerator{
repository,
factory,
}
}
func (e *IamRolePolicyAttachmentEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsIamRolePolicyAttachmentResourceType
}
func (e *IamRolePolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) {
roles, err := e.repository.ListAllRoles()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamRoleResourceType)
}
results := make([]*resource.Resource, 0)
rolesNotIgnored := make([]*iam.Role, 0)
for _, role := range roles {
if role.RoleName != nil && awsIamRoleShouldBeIgnored(*role.RoleName) {
continue
}
rolesNotIgnored = append(rolesNotIgnored, role)
}
if len(rolesNotIgnored) == 0 {
return results, nil
}
policyAttachments, err := e.repository.ListAllRolePolicyAttachments(rolesNotIgnored)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, attachedPol := range policyAttachments {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.RoleName),
map[string]interface{}{
"role": attachedPol.RoleName,
"policy_arn": *attachedPol.PolicyArn,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,54 @@
package aws
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamRolePolicyEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamRolePolicyEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamRolePolicyEnumerator {
return &IamRolePolicyEnumerator{
repository,
factory,
}
}
func (e *IamRolePolicyEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsIamRolePolicyResourceType
}
func (e *IamRolePolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
roles, err := e.repository.ListAllRoles()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamRoleResourceType)
}
policies, err := e.repository.ListAllRolePolicies(roles)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(policies))
for _, policy := range policies {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
fmt.Sprintf("%s:%s", policy.RoleName, policy.Policy),
map[string]interface{}{
"role": policy.RoleName,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,47 @@
package aws
import (
awssdk "github.com/aws/aws-sdk-go/aws"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamUserEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamUserEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamUserEnumerator {
return &IamUserEnumerator{
repository: repo,
factory: factory,
}
}
func (e *IamUserEnumerator) SupportedType() resource.ResourceType {
return aws.AwsIamUserResourceType
}
func (e *IamUserEnumerator) Enumerate() ([]*resource.Resource, error) {
users, err := e.repository.ListAllUsers()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(users))
for _, user := range users {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
awssdk.StringValue(user.UserName),
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,55 @@
package aws
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamUserPolicyAttachmentEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamUserPolicyAttachmentEnumerator(repository repository.IAMRepository, factory resource.ResourceFactory) *IamUserPolicyAttachmentEnumerator {
return &IamUserPolicyAttachmentEnumerator{
repository,
factory,
}
}
func (e *IamUserPolicyAttachmentEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsIamUserPolicyAttachmentResourceType
}
func (e *IamUserPolicyAttachmentEnumerator) Enumerate() ([]*resource.Resource, error) {
users, err := e.repository.ListAllUsers()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), resourceaws.AwsIamUserResourceType)
}
results := make([]*resource.Resource, 0)
policyAttachments, err := e.repository.ListAllUserPolicyAttachments(users)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, attachedPol := range policyAttachments {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
fmt.Sprintf("%s-%s", *attachedPol.PolicyName, attachedPol.UserName),
map[string]interface{}{
"user": attachedPol.UserName,
"policy_arn": *attachedPol.PolicyArn,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,50 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type IamUserPolicyEnumerator struct {
repository repository.IAMRepository
factory resource.ResourceFactory
}
func NewIamUserPolicyEnumerator(repo repository.IAMRepository, factory resource.ResourceFactory) *IamUserPolicyEnumerator {
return &IamUserPolicyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *IamUserPolicyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsIamUserPolicyResourceType
}
func (e *IamUserPolicyEnumerator) Enumerate() ([]*resource.Resource, error) {
users, err := e.repository.ListAllUsers()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsIamUserResourceType)
}
userPolicies, err := e.repository.ListAllUserPolicies(users)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(userPolicies))
for _, userPolicy := range userPolicies {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
userPolicy,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,258 @@
package aws
import (
"github.com/snyk/driftctl/enumeration"
"github.com/snyk/driftctl/enumeration/alerter"
"github.com/snyk/driftctl/enumeration/remote/aws/client"
repository2 "github.com/snyk/driftctl/enumeration/remote/aws/repository"
"github.com/snyk/driftctl/enumeration/remote/cache"
common2 "github.com/snyk/driftctl/enumeration/remote/common"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
"github.com/snyk/driftctl/enumeration/terraform"
)
/**
* Initialize remote (configure credentials, launch tf providers and start gRPC clients)
* Required to use Scanner
*/
func Init(version string, alerter *alerter.Alerter,
providerLibrary *terraform.ProviderLibrary,
remoteLibrary *common2.RemoteLibrary,
progress enumeration.ProgressCounter,
resourceSchemaRepository *resource.SchemaRepository,
factory resource.ResourceFactory,
configDir string) error {
provider, err := NewAWSTerraformProvider(version, progress, configDir)
if err != nil {
return err
}
err = provider.CheckCredentialsExist()
if err != nil {
return err
}
err = provider.Init()
if err != nil {
return err
}
repositoryCache := cache.New(100)
s3Repository := repository2.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
ec2repository := repository2.NewEC2Repository(provider.session, repositoryCache)
elbv2Repository := repository2.NewELBV2Repository(provider.session, repositoryCache)
route53repository := repository2.NewRoute53Repository(provider.session, repositoryCache)
lambdaRepository := repository2.NewLambdaRepository(provider.session, repositoryCache)
rdsRepository := repository2.NewRDSRepository(provider.session, repositoryCache)
sqsRepository := repository2.NewSQSRepository(provider.session, repositoryCache)
snsRepository := repository2.NewSNSRepository(provider.session, repositoryCache)
cloudfrontRepository := repository2.NewCloudfrontRepository(provider.session, repositoryCache)
dynamoDBRepository := repository2.NewDynamoDBRepository(provider.session, repositoryCache)
ecrRepository := repository2.NewECRRepository(provider.session, repositoryCache)
kmsRepository := repository2.NewKMSRepository(provider.session, repositoryCache)
iamRepository := repository2.NewIAMRepository(provider.session, repositoryCache)
cloudformationRepository := repository2.NewCloudformationRepository(provider.session, repositoryCache)
apigatewayRepository := repository2.NewApiGatewayRepository(provider.session, repositoryCache)
appAutoScalingRepository := repository2.NewAppAutoScalingRepository(provider.session, repositoryCache)
apigatewayv2Repository := repository2.NewApiGatewayV2Repository(provider.session, repositoryCache)
autoscalingRepository := repository2.NewAutoScalingRepository(provider.session, repositoryCache)
elbRepository := repository2.NewELBRepository(provider.session, repositoryCache)
elasticacheRepository := repository2.NewElastiCacheRepository(provider.session, repositoryCache)
deserializer := resource.NewDeserializer(factory)
providerLibrary.AddProvider(terraform.AWS, provider)
remoteLibrary.AddEnumerator(NewS3BucketEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketInventoryEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketInventoryResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketInventoryResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketNotificationEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketNotificationResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketNotificationResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketMetricsEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketMetricResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketMetricResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketPolicyEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketPolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common2.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common2.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2EbsSnapshotEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEbsSnapshotResourceType, common2.NewGenericDetailsFetcher(aws.AwsEbsSnapshotResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2EipEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEipResourceType, common2.NewGenericDetailsFetcher(aws.AwsEipResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2AmiEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsAmiResourceType, common2.NewGenericDetailsFetcher(aws.AwsAmiResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2KeyPairEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsKeyPairResourceType, common2.NewGenericDetailsFetcher(aws.AwsKeyPairResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2EipAssociationEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEipAssociationResourceType, common2.NewGenericDetailsFetcher(aws.AwsEipAssociationResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2InstanceEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsInstanceResourceType, common2.NewGenericDetailsFetcher(aws.AwsInstanceResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2InternetGatewayEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsInternetGatewayResourceType, common2.NewGenericDetailsFetcher(aws.AwsInternetGatewayResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewVPCEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsVpcResourceType, common2.NewGenericDetailsFetcher(aws.AwsVpcResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewDefaultVPCEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDefaultVpcResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultVpcResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2RouteTableEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRouteTableResourceType, common2.NewGenericDetailsFetcher(aws.AwsRouteTableResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2DefaultRouteTableEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDefaultRouteTableResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultRouteTableResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2RouteTableAssociationEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRouteTableAssociationResourceType, common2.NewGenericDetailsFetcher(aws.AwsRouteTableAssociationResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2SubnetEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSubnetResourceType, common2.NewGenericDetailsFetcher(aws.AwsSubnetResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2DefaultSubnetEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDefaultSubnetResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultSubnetResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewVPCSecurityGroupEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSecurityGroupResourceType, common2.NewGenericDetailsFetcher(aws.AwsSecurityGroupResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewVPCDefaultSecurityGroupEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDefaultSecurityGroupResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultSecurityGroupResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2NatGatewayEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsNatGatewayResourceType, common2.NewGenericDetailsFetcher(aws.AwsNatGatewayResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2NetworkACLEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsNetworkACLResourceType, common2.NewGenericDetailsFetcher(aws.AwsNetworkACLResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2NetworkACLRuleEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsNetworkACLRuleResourceType, common2.NewGenericDetailsFetcher(aws.AwsNetworkACLRuleResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2DefaultNetworkACLEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDefaultNetworkACLResourceType, common2.NewGenericDetailsFetcher(aws.AwsDefaultNetworkACLResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2RouteEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRouteResourceType, common2.NewGenericDetailsFetcher(aws.AwsRouteResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewVPCSecurityGroupRuleEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSecurityGroupRuleResourceType, common2.NewGenericDetailsFetcher(aws.AwsSecurityGroupRuleResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewLaunchTemplateEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsLaunchTemplateResourceType, common2.NewGenericDetailsFetcher(aws.AwsLaunchTemplateResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewEC2EbsEncryptionByDefaultEnumerator(ec2repository, factory))
remoteLibrary.AddEnumerator(NewKMSKeyEnumerator(kmsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsKmsKeyResourceType, common2.NewGenericDetailsFetcher(aws.AwsKmsKeyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewKMSAliasEnumerator(kmsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsKmsAliasResourceType, common2.NewGenericDetailsFetcher(aws.AwsKmsAliasResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewRoute53HealthCheckEnumerator(route53repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRoute53HealthCheckResourceType, common2.NewGenericDetailsFetcher(aws.AwsRoute53HealthCheckResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewRoute53ZoneEnumerator(route53repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRoute53ZoneResourceType, common2.NewGenericDetailsFetcher(aws.AwsRoute53ZoneResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewRoute53RecordEnumerator(route53repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRoute53RecordResourceType, common2.NewGenericDetailsFetcher(aws.AwsRoute53RecordResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewCloudfrontDistributionEnumerator(cloudfrontRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsCloudfrontDistributionResourceType, common2.NewGenericDetailsFetcher(aws.AwsCloudfrontDistributionResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewRDSDBInstanceEnumerator(rdsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDbInstanceResourceType, common2.NewGenericDetailsFetcher(aws.AwsDbInstanceResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewRDSDBSubnetGroupEnumerator(rdsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDbSubnetGroupResourceType, common2.NewGenericDetailsFetcher(aws.AwsDbSubnetGroupResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewSQSQueueEnumerator(sqsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSqsQueueResourceType, NewSQSQueueDetailsFetcher(provider, deserializer))
remoteLibrary.AddEnumerator(NewSQSQueuePolicyEnumerator(sqsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSqsQueuePolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsSqsQueuePolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewSNSTopicEnumerator(snsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicResourceType, common2.NewGenericDetailsFetcher(aws.AwsSnsTopicResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewSNSTopicPolicyEnumerator(snsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsSnsTopicPolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewSNSTopicSubscriptionEnumerator(snsRepository, factory, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsSnsTopicSubscriptionResourceType, common2.NewGenericDetailsFetcher(aws.AwsSnsTopicSubscriptionResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewDynamoDBTableEnumerator(dynamoDBRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsDynamodbTableResourceType, common2.NewGenericDetailsFetcher(aws.AwsDynamodbTableResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamPolicyEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamPolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewLambdaFunctionEnumerator(lambdaRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsLambdaFunctionResourceType, common2.NewGenericDetailsFetcher(aws.AwsLambdaFunctionResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewLambdaEventSourceMappingEnumerator(lambdaRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsLambdaEventSourceMappingResourceType, common2.NewGenericDetailsFetcher(aws.AwsLambdaEventSourceMappingResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamUserEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamUserResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamUserResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamUserPolicyEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamUserPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamUserPolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamRoleEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamRoleResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamRoleResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamAccessKeyEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamAccessKeyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamAccessKeyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamRolePolicyAttachmentEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamRolePolicyAttachmentResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamRolePolicyAttachmentResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamRolePolicyEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamRolePolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamRolePolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamUserPolicyAttachmentEnumerator(iamRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsIamUserPolicyAttachmentResourceType, common2.NewGenericDetailsFetcher(aws.AwsIamUserPolicyAttachmentResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewIamGroupPolicyEnumerator(iamRepository, factory))
remoteLibrary.AddEnumerator(NewIamGroupEnumerator(iamRepository, factory))
remoteLibrary.AddEnumerator(NewIamGroupPolicyAttachmentEnumerator(iamRepository, factory))
remoteLibrary.AddEnumerator(NewECRRepositoryEnumerator(ecrRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEcrRepositoryResourceType, common2.NewGenericDetailsFetcher(aws.AwsEcrRepositoryResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewECRRepositoryPolicyEnumerator(ecrRepository, factory))
remoteLibrary.AddEnumerator(NewRDSClusterEnumerator(rdsRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsRDSClusterResourceType, common2.NewGenericDetailsFetcher(aws.AwsRDSClusterResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewCloudformationStackEnumerator(cloudformationRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsCloudformationStackResourceType, common2.NewGenericDetailsFetcher(aws.AwsCloudformationStackResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewApiGatewayRestApiEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayAccountEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayApiKeyEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayAuthorizerEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayStageEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayResourceEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayDomainNameEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayVpcLinkEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayRequestValidatorEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayRestApiPolicyEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayBasePathMappingEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayMethodEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayModelEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayMethodResponseEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayGatewayResponseEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayMethodSettingsEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayIntegrationEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayIntegrationResponseEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2ApiEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2RouteEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2DeploymentEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2VpcLinkEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2AuthorizerEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2IntegrationEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2ModelEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2StageEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2RouteResponseEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2MappingEnumerator(apigatewayv2Repository, apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2DomainNameEnumerator(apigatewayRepository, factory))
remoteLibrary.AddEnumerator(NewApiGatewayV2IntegrationResponseEnumerator(apigatewayv2Repository, factory))
remoteLibrary.AddEnumerator(NewAppAutoscalingTargetEnumerator(appAutoScalingRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsAppAutoscalingTargetResourceType, common2.NewGenericDetailsFetcher(aws.AwsAppAutoscalingTargetResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewAppAutoscalingPolicyEnumerator(appAutoScalingRepository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsAppAutoscalingPolicyResourceType, common2.NewGenericDetailsFetcher(aws.AwsAppAutoscalingPolicyResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewAppAutoscalingScheduledActionEnumerator(appAutoScalingRepository, factory))
remoteLibrary.AddEnumerator(NewLaunchConfigurationEnumerator(autoscalingRepository, factory))
remoteLibrary.AddEnumerator(NewLoadBalancerEnumerator(elbv2Repository, factory))
remoteLibrary.AddEnumerator(NewLoadBalancerListenerEnumerator(elbv2Repository, factory))
remoteLibrary.AddEnumerator(NewClassicLoadBalancerEnumerator(elbRepository, factory))
remoteLibrary.AddEnumerator(NewElastiCacheClusterEnumerator(elasticacheRepository, factory))
err = resourceSchemaRepository.Init(terraform.AWS, provider.Version(), provider.Schema())
if err != nil {
return err
}
aws.InitResourcesMetadata(resourceSchemaRepository)
return nil
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type KMSAliasEnumerator struct {
repository repository.KMSRepository
factory resource.ResourceFactory
}
func NewKMSAliasEnumerator(repo repository.KMSRepository, factory resource.ResourceFactory) *KMSAliasEnumerator {
return &KMSAliasEnumerator{
repository: repo,
factory: factory,
}
}
func (e *KMSAliasEnumerator) SupportedType() resource.ResourceType {
return aws.AwsKmsAliasResourceType
}
func (e *KMSAliasEnumerator) Enumerate() ([]*resource.Resource, error) {
aliases, err := e.repository.ListAllAliases()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(aliases))
for _, alias := range aliases {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*alias.AliasName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type KMSKeyEnumerator struct {
repository repository.KMSRepository
factory resource.ResourceFactory
}
func NewKMSKeyEnumerator(repo repository.KMSRepository, factory resource.ResourceFactory) *KMSKeyEnumerator {
return &KMSKeyEnumerator{
repository: repo,
factory: factory,
}
}
func (e *KMSKeyEnumerator) SupportedType() resource.ResourceType {
return aws.AwsKmsKeyResourceType
}
func (e *KMSKeyEnumerator) Enumerate() ([]*resource.Resource, error) {
keys, err := e.repository.ListAllKeys()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(keys))
for _, key := range keys {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*key.KeyId,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type LambdaEventSourceMappingEnumerator struct {
repository repository.LambdaRepository
factory resource.ResourceFactory
}
func NewLambdaEventSourceMappingEnumerator(repo repository.LambdaRepository, factory resource.ResourceFactory) *LambdaEventSourceMappingEnumerator {
return &LambdaEventSourceMappingEnumerator{
repository: repo,
factory: factory,
}
}
func (e *LambdaEventSourceMappingEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsLambdaEventSourceMappingResourceType
}
func (e *LambdaEventSourceMappingEnumerator) Enumerate() ([]*resource.Resource, error) {
eventSourceMappings, err := e.repository.ListAllLambdaEventSourceMappings()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(eventSourceMappings))
for _, eventSourceMapping := range eventSourceMappings {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*eventSourceMapping.UUID,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
resourceaws "github.com/snyk/driftctl/enumeration/resource/aws"
)
type LambdaFunctionEnumerator struct {
repository repository.LambdaRepository
factory resource.ResourceFactory
}
func NewLambdaFunctionEnumerator(repo repository.LambdaRepository, factory resource.ResourceFactory) *LambdaFunctionEnumerator {
return &LambdaFunctionEnumerator{
repository: repo,
factory: factory,
}
}
func (e *LambdaFunctionEnumerator) SupportedType() resource.ResourceType {
return resourceaws.AwsLambdaFunctionResourceType
}
func (e *LambdaFunctionEnumerator) Enumerate() ([]*resource.Resource, error) {
functions, err := e.repository.ListAllLambdaFunctions()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(functions))
for _, function := range functions {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*function.FunctionName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type LaunchConfigurationEnumerator struct {
repository repository.AutoScalingRepository
factory resource.ResourceFactory
}
func NewLaunchConfigurationEnumerator(repo repository.AutoScalingRepository, factory resource.ResourceFactory) *LaunchConfigurationEnumerator {
return &LaunchConfigurationEnumerator{
repository: repo,
factory: factory,
}
}
func (e *LaunchConfigurationEnumerator) SupportedType() resource.ResourceType {
return aws.AwsLaunchConfigurationResourceType
}
func (e *LaunchConfigurationEnumerator) Enumerate() ([]*resource.Resource, error) {
configs, err := e.repository.DescribeLaunchConfigurations()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(configs))
for _, config := range configs {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*config.LaunchConfigurationName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type LaunchTemplateEnumerator struct {
repository repository.EC2Repository
factory resource.ResourceFactory
}
func NewLaunchTemplateEnumerator(repo repository.EC2Repository, factory resource.ResourceFactory) *LaunchTemplateEnumerator {
return &LaunchTemplateEnumerator{
repository: repo,
factory: factory,
}
}
func (e *LaunchTemplateEnumerator) SupportedType() resource.ResourceType {
return aws.AwsLaunchTemplateResourceType
}
func (e *LaunchTemplateEnumerator) Enumerate() ([]*resource.Resource, error) {
templates, err := e.repository.DescribeLaunchTemplates()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(templates))
for _, tmpl := range templates {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*tmpl.LaunchTemplateId,
map[string]interface{}{},
),
)
}
return results, nil
}

View File

@ -0,0 +1,48 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type LoadBalancerEnumerator struct {
repository repository.ELBV2Repository
factory resource.ResourceFactory
}
func NewLoadBalancerEnumerator(repo repository.ELBV2Repository, factory resource.ResourceFactory) *LoadBalancerEnumerator {
return &LoadBalancerEnumerator{
repository: repo,
factory: factory,
}
}
func (e *LoadBalancerEnumerator) SupportedType() resource.ResourceType {
return aws.AwsLoadBalancerResourceType
}
func (e *LoadBalancerEnumerator) Enumerate() ([]*resource.Resource, error) {
loadBalancers, err := e.repository.ListAllLoadBalancers()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(loadBalancers))
for _, lb := range loadBalancers {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*lb.LoadBalancerArn,
map[string]interface{}{
"name": *lb.LoadBalancerName,
},
),
)
}
return results, err
}

View File

@ -0,0 +1,53 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type LoadBalancerListenerEnumerator struct {
repository repository.ELBV2Repository
factory resource.ResourceFactory
}
func NewLoadBalancerListenerEnumerator(repo repository.ELBV2Repository, factory resource.ResourceFactory) *LoadBalancerListenerEnumerator {
return &LoadBalancerListenerEnumerator{
repository: repo,
factory: factory,
}
}
func (e *LoadBalancerListenerEnumerator) SupportedType() resource.ResourceType {
return aws.AwsLoadBalancerListenerResourceType
}
func (e *LoadBalancerListenerEnumerator) Enumerate() ([]*resource.Resource, error) {
loadBalancers, err := e.repository.ListAllLoadBalancers()
if err != nil {
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsLoadBalancerResourceType)
}
results := make([]*resource.Resource, 0)
for _, lb := range loadBalancers {
listeners, err := e.repository.ListAllLoadBalancerListeners(*lb.LoadBalancerArn)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, listener := range listeners {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*listener.ListenerArn,
map[string]interface{}{},
),
)
}
}
return results, nil
}

View File

@ -0,0 +1,119 @@
package aws
import (
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/snyk/driftctl/enumeration"
"github.com/snyk/driftctl/enumeration/remote/terraform"
tf "github.com/snyk/driftctl/enumeration/terraform"
)
type awsConfig struct {
AccessKey string
SecretKey string
CredsFilename string
Profile string
Token string
Region string `cty:"region"`
MaxRetries int
AssumeRoleARN string
AssumeRoleExternalID string
AssumeRoleSessionName string
AssumeRolePolicy string
AllowedAccountIds []string
ForbiddenAccountIds []string
Endpoints map[string]string
IgnoreTagsConfig map[string]string
Insecure bool
SkipCredsValidation bool `cty:"skip_credentials_validation"`
SkipGetEC2Platforms bool
SkipRegionValidation bool
SkipRequestingAccountId bool `cty:"skip_requesting_account_id"`
SkipMetadataApiCheck bool
S3ForcePathStyle bool
}
type AWSTerraformProvider struct {
*terraform.TerraformProvider
session *session.Session
name string
version string
}
func NewAWSTerraformProvider(version string, progress enumeration.ProgressCounter, configDir string) (*AWSTerraformProvider, error) {
if version == "" {
version = "3.19.0"
}
p := &AWSTerraformProvider{
version: version,
name: "aws",
}
installer, err := tf.NewProviderInstaller(tf.ProviderConfig{
Key: p.name,
Version: version,
ConfigDir: configDir,
})
if err != nil {
return nil, err
}
p.session = session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
tfProvider, err := terraform.NewTerraformProvider(installer, terraform.TerraformProviderConfig{
Name: p.name,
DefaultAlias: *p.session.Config.Region,
GetProviderConfig: func(alias string) interface{} {
return awsConfig{
Region: alias,
// Those two parameters are used to make sure that the credentials are not validated when calling
// Configure(). Credentials validation is now handled directly in driftctl
SkipCredsValidation: true,
SkipRequestingAccountId: true,
MaxRetries: 10, // TODO make this configurable
}
},
}, progress)
if err != nil {
return nil, err
}
p.TerraformProvider = tfProvider
return p, err
}
func (a *AWSTerraformProvider) Name() string {
return a.name
}
func (p *AWSTerraformProvider) Version() string {
return p.version
}
func (p *AWSTerraformProvider) CheckCredentialsExist() error {
_, err := p.session.Config.Credentials.Get()
if err == credentials.ErrNoValidProvidersFoundInChain {
return errors.New("Could not find a way to authenticate on AWS!\n" +
"Please refer to AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n\n" +
"To use a different cloud provider, use --to=\"gcp+tf\" for GCP or --to=\"azure+tf\" for Azure.")
}
if err != nil {
return err
}
// This call is to make sure that the credentials are valid
// A more complex logic exist in terraform provider, but it's probably not worth to implement it
// https://github.com/hashicorp/terraform-provider-aws/blob/e3959651092864925045a6044961a73137095798/aws/auth_helpers.go#L111
_, err = sts.New(p.session).GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
logrus.Debug(err)
return errors.New("Could not authenticate successfully on AWS with the provided credentials.\n" +
"Please refer to the AWS documentation: https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html\n")
}
return nil
}

View File

@ -0,0 +1,55 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type RDSClusterEnumerator struct {
repository repository.RDSRepository
factory resource.ResourceFactory
}
func NewRDSClusterEnumerator(repository repository.RDSRepository, factory resource.ResourceFactory) *RDSClusterEnumerator {
return &RDSClusterEnumerator{
repository,
factory,
}
}
func (e *RDSClusterEnumerator) SupportedType() resource.ResourceType {
return aws.AwsRDSClusterResourceType
}
func (e *RDSClusterEnumerator) Enumerate() ([]*resource.Resource, error) {
clusters, err := e.repository.ListAllDBClusters()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(clusters))
for _, cluster := range clusters {
var databaseName string
if v := cluster.DatabaseName; v != nil {
databaseName = *cluster.DatabaseName
}
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*cluster.DBClusterIdentifier,
map[string]interface{}{
"cluster_identifier": *cluster.DBClusterIdentifier,
"database_name": databaseName,
},
),
)
}
return results, nil
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type RDSDBInstanceEnumerator struct {
repository repository.RDSRepository
factory resource.ResourceFactory
}
func NewRDSDBInstanceEnumerator(repo repository.RDSRepository, factory resource.ResourceFactory) *RDSDBInstanceEnumerator {
return &RDSDBInstanceEnumerator{
repository: repo,
factory: factory,
}
}
func (e *RDSDBInstanceEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDbInstanceResourceType
}
func (e *RDSDBInstanceEnumerator) Enumerate() ([]*resource.Resource, error) {
instances, err := e.repository.ListAllDBInstances()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(instances))
for _, instance := range instances {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*instance.DBInstanceIdentifier,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,46 @@
package aws
import (
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
"github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws"
)
type RDSDBSubnetGroupEnumerator struct {
repository repository.RDSRepository
factory resource.ResourceFactory
}
func NewRDSDBSubnetGroupEnumerator(repo repository.RDSRepository, factory resource.ResourceFactory) *RDSDBSubnetGroupEnumerator {
return &RDSDBSubnetGroupEnumerator{
repository: repo,
factory: factory,
}
}
func (e *RDSDBSubnetGroupEnumerator) SupportedType() resource.ResourceType {
return aws.AwsDbSubnetGroupResourceType
}
func (e *RDSDBSubnetGroupEnumerator) Enumerate() ([]*resource.Resource, error) {
subnetGroups, err := e.repository.ListAllDBSubnetGroups()
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0, len(subnetGroups))
for _, subnetGroup := range subnetGroups {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*subnetGroup.DBSubnetGroupName,
map[string]interface{}{},
),
)
}
return results, err
}

View File

@ -0,0 +1,285 @@
package repository
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/cache"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigateway"
"github.com/aws/aws-sdk-go/service/apigateway/apigatewayiface"
)
type ApiGatewayRepository interface {
ListAllRestApis() ([]*apigateway.RestApi, error)
GetAccount() (*apigateway.Account, error)
ListAllApiKeys() ([]*apigateway.ApiKey, error)
ListAllRestApiAuthorizers(string) ([]*apigateway.Authorizer, error)
ListAllRestApiStages(string) ([]*apigateway.Stage, error)
ListAllRestApiResources(string) ([]*apigateway.Resource, error)
ListAllDomainNames() ([]*apigateway.DomainName, error)
ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOutput, error)
ListAllRestApiRequestValidators(string) ([]*apigateway.UpdateRequestValidatorOutput, error)
ListAllDomainNameBasePathMappings(string) ([]*apigateway.BasePathMapping, error)
ListAllRestApiModels(string) ([]*apigateway.Model, error)
ListAllRestApiGatewayResponses(string) ([]*apigateway.UpdateGatewayResponseOutput, error)
}
type apigatewayRepository struct {
client apigatewayiface.APIGatewayAPI
cache cache.Cache
}
func NewApiGatewayRepository(session *session.Session, c cache.Cache) *apigatewayRepository {
return &apigatewayRepository{
apigateway.New(session),
c,
}
}
func (r *apigatewayRepository) ListAllRestApis() ([]*apigateway.RestApi, error) {
cacheKey := "apigatewayListAllRestApis"
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
return v.([]*apigateway.RestApi), nil
}
var restApis []*apigateway.RestApi
input := apigateway.GetRestApisInput{}
err := r.client.GetRestApisPages(&input,
func(resp *apigateway.GetRestApisOutput, lastPage bool) bool {
restApis = append(restApis, resp.Items...)
return !lastPage
},
)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, restApis)
return restApis, nil
}
func (r *apigatewayRepository) GetAccount() (*apigateway.Account, error) {
if v := r.cache.Get("apigatewayGetAccount"); v != nil {
return v.(*apigateway.Account), nil
}
account, err := r.client.GetAccount(&apigateway.GetAccountInput{})
if err != nil {
return nil, err
}
r.cache.Put("apigatewayGetAccount", account)
return account, nil
}
func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
if v := r.cache.Get("apigatewayListAllApiKeys"); v != nil {
return v.([]*apigateway.ApiKey), nil
}
var apiKeys []*apigateway.ApiKey
input := apigateway.GetApiKeysInput{}
err := r.client.GetApiKeysPages(&input,
func(resp *apigateway.GetApiKeysOutput, lastPage bool) bool {
apiKeys = append(apiKeys, resp.Items...)
return !lastPage
},
)
if err != nil {
return nil, err
}
r.cache.Put("apigatewayListAllApiKeys", apiKeys)
return apiKeys, nil
}
func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigateway.Authorizer), nil
}
input := &apigateway.GetAuthorizersInput{
RestApiId: &apiId,
}
resources, err := r.client.GetAuthorizers(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway.Stage, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiStages_api_%s", apiId)
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
return v.([]*apigateway.Stage), nil
}
input := &apigateway.GetStagesInput{
RestApiId: &apiId,
}
resources, err := r.client.GetStages(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Item)
return resources.Item, nil
}
func (r *apigatewayRepository) ListAllRestApiResources(apiId string) ([]*apigateway.Resource, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiResources_api_%s", apiId)
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
return v.([]*apigateway.Resource), nil
}
var resources []*apigateway.Resource
input := &apigateway.GetResourcesInput{
RestApiId: &apiId,
Embed: []*string{aws.String("methods")},
}
err := r.client.GetResourcesPages(input, func(res *apigateway.GetResourcesOutput, lastPage bool) bool {
resources = append(resources, res.Items...)
return !lastPage
})
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources)
return resources, nil
}
func (r *apigatewayRepository) ListAllDomainNames() ([]*apigateway.DomainName, error) {
cacheKey := "apigatewayListAllDomainNames"
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
return v.([]*apigateway.DomainName), nil
}
var domainNames []*apigateway.DomainName
input := apigateway.GetDomainNamesInput{}
err := r.client.GetDomainNamesPages(&input,
func(resp *apigateway.GetDomainNamesOutput, lastPage bool) bool {
domainNames = append(domainNames, resp.Items...)
return !lastPage
},
)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, domainNames)
return domainNames, nil
}
func (r *apigatewayRepository) ListAllVpcLinks() ([]*apigateway.UpdateVpcLinkOutput, error) {
if v := r.cache.Get("apigatewayListAllVpcLinks"); v != nil {
return v.([]*apigateway.UpdateVpcLinkOutput), nil
}
var vpcLinks []*apigateway.UpdateVpcLinkOutput
input := apigateway.GetVpcLinksInput{}
err := r.client.GetVpcLinksPages(&input,
func(resp *apigateway.GetVpcLinksOutput, lastPage bool) bool {
vpcLinks = append(vpcLinks, resp.Items...)
return !lastPage
},
)
if err != nil {
return nil, err
}
r.cache.Put("apigatewayListAllVpcLinks", vpcLinks)
return vpcLinks, nil
}
func (r *apigatewayRepository) ListAllRestApiRequestValidators(apiId string) ([]*apigateway.UpdateRequestValidatorOutput, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiRequestValidators_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigateway.UpdateRequestValidatorOutput), nil
}
input := &apigateway.GetRequestValidatorsInput{
RestApiId: &apiId,
}
resources, err := r.client.GetRequestValidators(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayRepository) ListAllDomainNameBasePathMappings(domainName string) ([]*apigateway.BasePathMapping, error) {
cacheKey := fmt.Sprintf("apigatewayListAllDomainNameBasePathMappings_domainName_%s", domainName)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigateway.BasePathMapping), nil
}
var mappings []*apigateway.BasePathMapping
input := &apigateway.GetBasePathMappingsInput{
DomainName: &domainName,
}
err := r.client.GetBasePathMappingsPages(input, func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool {
mappings = append(mappings, res.Items...)
return !lastPage
})
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, mappings)
return mappings, nil
}
func (r *apigatewayRepository) ListAllRestApiModels(apiId string) ([]*apigateway.Model, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiModels_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigateway.Model), nil
}
var resources []*apigateway.Model
input := &apigateway.GetModelsInput{
RestApiId: &apiId,
}
err := r.client.GetModelsPages(input, func(res *apigateway.GetModelsOutput, lastPage bool) bool {
resources = append(resources, res.Items...)
return !lastPage
})
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources)
return resources, nil
}
func (r *apigatewayRepository) ListAllRestApiGatewayResponses(apiId string) ([]*apigateway.UpdateGatewayResponseOutput, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiGatewayResponses_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigateway.UpdateGatewayResponseOutput), nil
}
input := &apigateway.GetGatewayResponsesInput{
RestApiId: &apiId,
}
resources, err := r.client.GetGatewayResponses(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}

View File

@ -0,0 +1,890 @@
package repository
import (
"github.com/snyk/driftctl/enumeration/remote/cache"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/apigateway"
"github.com/pkg/errors"
awstest "github.com/snyk/driftctl/test/aws"
"github.com/stretchr/testify/mock"
"github.com/r3labs/diff/v2"
"github.com/stretchr/testify/assert"
)
func Test_apigatewayRepository_ListAllRestApis(t *testing.T) {
apis := []*apigateway.RestApi{
{Id: aws.String("restapi1")},
{Id: aws.String("restapi2")},
{Id: aws.String("restapi3")},
{Id: aws.String("restapi4")},
{Id: aws.String("restapi5")},
{Id: aws.String("restapi6")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.RestApi
wantErr error
}{
{
name: "list multiple rest apis",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetRestApisPages",
&apigateway.GetRestApisInput{},
mock.MatchedBy(func(callback func(res *apigateway.GetRestApisOutput, lastPage bool) bool) bool {
callback(&apigateway.GetRestApisOutput{
Items: apis[:3],
}, false)
callback(&apigateway.GetRestApisOutput{
Items: apis[3:],
}, true)
return true
})).Return(nil).Once()
store.On("GetAndLock", "apigatewayListAllRestApis").Return(nil).Times(1)
store.On("Unlock", "apigatewayListAllRestApis").Times(1)
store.On("Put", "apigatewayListAllRestApis", apis).Return(false).Times(1)
},
want: apis,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("GetAndLock", "apigatewayListAllRestApis").Return(apis).Times(1)
store.On("Unlock", "apigatewayListAllRestApis").Times(1)
},
want: apis,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApis()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_GetAccount(t *testing.T) {
account := &apigateway.Account{
CloudwatchRoleArn: aws.String("arn:aws:iam::017011014111:role/api_gateway_cloudwatch_global"),
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want *apigateway.Account
wantErr error
}{
{
name: "get a single account",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetAccount", &apigateway.GetAccountInput{}).Return(account, nil).Once()
store.On("Get", "apigatewayGetAccount").Return(nil).Times(1)
store.On("Put", "apigatewayGetAccount", account).Return(false).Times(1)
},
want: account,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayGetAccount").Return(account).Times(1)
},
want: account,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.GetAccount()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllApiKeys(t *testing.T) {
keys := []*apigateway.ApiKey{
{Id: aws.String("apikey1")},
{Id: aws.String("apikey2")},
{Id: aws.String("apikey3")},
{Id: aws.String("apikey4")},
{Id: aws.String("apikey5")},
{Id: aws.String("apikey6")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.ApiKey
wantErr error
}{
{
name: "list multiple api keys",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetApiKeysPages",
&apigateway.GetApiKeysInput{},
mock.MatchedBy(func(callback func(res *apigateway.GetApiKeysOutput, lastPage bool) bool) bool {
callback(&apigateway.GetApiKeysOutput{
Items: keys[:3],
}, false)
callback(&apigateway.GetApiKeysOutput{
Items: keys[3:],
}, true)
return true
})).Return(nil).Once()
store.On("Get", "apigatewayListAllApiKeys").Return(nil).Times(1)
store.On("Put", "apigatewayListAllApiKeys", keys).Return(false).Times(1)
},
want: keys,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllApiKeys").Return(keys).Times(1)
},
want: keys,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllApiKeys()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllRestApiAuthorizers(t *testing.T) {
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
apiAuthorizers := []*apigateway.Authorizer{
{Id: aws.String("resource1")},
{Id: aws.String("resource2")},
{Id: aws.String("resource3")},
{Id: aws.String("resource4")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.Authorizer
wantErr error
}{
{
name: "list multiple rest api authorizers",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetAuthorizers",
&apigateway.GetAuthorizersInput{
RestApiId: aws.String("restapi1"),
}).Return(&apigateway.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once()
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(nil).Times(1)
store.On("Put", "apigatewayListAllRestApiAuthorizers_api_restapi1", apiAuthorizers).Return(false).Times(1)
},
want: apiAuthorizers,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(apiAuthorizers).Times(1)
},
want: apiAuthorizers,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApiAuthorizers(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllRestApiStages(t *testing.T) {
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
apiStages := []*apigateway.Stage{
{StageName: aws.String("stage1")},
{StageName: aws.String("stage2")},
{StageName: aws.String("stage3")},
{StageName: aws.String("stage4")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.Stage
wantErr error
}{
{
name: "list multiple rest api stages",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetStages",
&apigateway.GetStagesInput{
RestApiId: aws.String("restapi1"),
}).Return(&apigateway.GetStagesOutput{Item: apiStages}, nil).Once()
store.On("GetAndLock", "apigatewayListAllRestApiStages_api_restapi1").Return(nil).Times(1)
store.On("Unlock", "apigatewayListAllRestApiStages_api_restapi1").Times(1)
store.On("Put", "apigatewayListAllRestApiStages_api_restapi1", apiStages).Return(false).Times(1)
},
want: apiStages,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("GetAndLock", "apigatewayListAllRestApiStages_api_restapi1").Return(apiStages).Times(1)
store.On("Unlock", "apigatewayListAllRestApiStages_api_restapi1").Times(1)
},
want: apiStages,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApiStages(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllRestApiResources(t *testing.T) {
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
apiResources := []*apigateway.Resource{
{Id: aws.String("resource1")},
{Id: aws.String("resource2")},
{Id: aws.String("resource3")},
{Id: aws.String("resource4")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.Resource
wantErr error
}{
{
name: "list multiple rest api resources",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetResourcesPages",
&apigateway.GetResourcesInput{
RestApiId: aws.String("restapi1"),
Embed: []*string{aws.String("methods")},
},
mock.MatchedBy(func(callback func(res *apigateway.GetResourcesOutput, lastPage bool) bool) bool {
callback(&apigateway.GetResourcesOutput{
Items: apiResources,
}, true)
return true
})).Return(nil).Once()
store.On("GetAndLock", "apigatewayListAllRestApiResources_api_restapi1").Return(nil).Times(1)
store.On("Unlock", "apigatewayListAllRestApiResources_api_restapi1").Times(1)
store.On("Put", "apigatewayListAllRestApiResources_api_restapi1", apiResources).Return(false).Times(1)
},
want: apiResources,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("GetAndLock", "apigatewayListAllRestApiResources_api_restapi1").Return(apiResources).Times(1)
store.On("Unlock", "apigatewayListAllRestApiResources_api_restapi1").Times(1)
},
want: apiResources,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApiResources(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllDomainNames(t *testing.T) {
domainNames := []*apigateway.DomainName{
{DomainName: aws.String("domainName1")},
{DomainName: aws.String("domainName2")},
{DomainName: aws.String("domainName3")},
{DomainName: aws.String("domainName4")},
{DomainName: aws.String("domainName5")},
{DomainName: aws.String("domainName6")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.DomainName
wantErr error
}{
{
name: "list multiple domain names",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetDomainNamesPages",
&apigateway.GetDomainNamesInput{},
mock.MatchedBy(func(callback func(res *apigateway.GetDomainNamesOutput, lastPage bool) bool) bool {
callback(&apigateway.GetDomainNamesOutput{
Items: domainNames[:3],
}, false)
callback(&apigateway.GetDomainNamesOutput{
Items: domainNames[3:],
}, true)
return true
})).Return(nil).Once()
store.On("GetAndLock", "apigatewayListAllDomainNames").Return(nil).Times(1)
store.On("Unlock", "apigatewayListAllDomainNames").Times(1)
store.On("Put", "apigatewayListAllDomainNames", domainNames).Return(false).Times(1)
},
want: domainNames,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("GetAndLock", "apigatewayListAllDomainNames").Return(domainNames).Times(1)
store.On("Unlock", "apigatewayListAllDomainNames").Times(1)
},
want: domainNames,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllDomainNames()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllVpcLinks(t *testing.T) {
vpcLinks := []*apigateway.UpdateVpcLinkOutput{
{Id: aws.String("vpcLink1")},
{Id: aws.String("vpcLink2")},
{Id: aws.String("vpcLink3")},
{Id: aws.String("vpcLink4")},
{Id: aws.String("vpcLink5")},
{Id: aws.String("vpcLink6")},
}
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.UpdateVpcLinkOutput
wantErr error
}{
{
name: "list multiple vpc links",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetVpcLinksPages",
&apigateway.GetVpcLinksInput{},
mock.MatchedBy(func(callback func(res *apigateway.GetVpcLinksOutput, lastPage bool) bool) bool {
callback(&apigateway.GetVpcLinksOutput{
Items: vpcLinks[:3],
}, false)
callback(&apigateway.GetVpcLinksOutput{
Items: vpcLinks[3:],
}, true)
return true
})).Return(nil).Once()
store.On("Get", "apigatewayListAllVpcLinks").Return(nil).Times(1)
store.On("Put", "apigatewayListAllVpcLinks", vpcLinks).Return(false).Times(1)
},
want: vpcLinks,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllVpcLinks").Return(vpcLinks).Times(1)
},
want: vpcLinks,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllVpcLinks()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllRestApiRequestValidators(t *testing.T) {
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
requestValidators := []*apigateway.UpdateRequestValidatorOutput{
{Id: aws.String("reqVal1")},
{Id: aws.String("reqVal2")},
{Id: aws.String("reqVal3")},
{Id: aws.String("reqVal4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.UpdateRequestValidatorOutput
wantErr error
}{
{
name: "list multiple rest api request validators",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetRequestValidators",
&apigateway.GetRequestValidatorsInput{
RestApiId: aws.String("restapi1"),
}).Return(&apigateway.GetRequestValidatorsOutput{Items: requestValidators}, nil).Once()
store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(nil).Times(1)
store.On("Put", "apigatewayListAllRestApiRequestValidators_api_restapi1", requestValidators).Return(false).Times(1)
},
want: requestValidators,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(requestValidators).Times(1)
},
want: requestValidators,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetRequestValidators",
&apigateway.GetRequestValidatorsInput{
RestApiId: aws.String("restapi1"),
}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayListAllRestApiRequestValidators_api_restapi1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApiRequestValidators(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllDomainNameBasePathMappings(t *testing.T) {
domainName := &apigateway.DomainName{
DomainName: aws.String("domainName1"),
}
mappings := []*apigateway.BasePathMapping{
{BasePath: aws.String("path1")},
{BasePath: aws.String("path2")},
{BasePath: aws.String("path3")},
{BasePath: aws.String("path4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.BasePathMapping
wantErr error
}{
{
name: "list multiple domain name base path mappings",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetBasePathMappingsPages",
&apigateway.GetBasePathMappingsInput{
DomainName: aws.String("domainName1"),
},
mock.MatchedBy(func(callback func(res *apigateway.GetBasePathMappingsOutput, lastPage bool) bool) bool {
callback(&apigateway.GetBasePathMappingsOutput{
Items: mappings,
}, true)
return true
})).Return(nil).Once()
store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(nil).Times(1)
store.On("Put", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1", mappings).Return(false).Times(1)
},
want: mappings,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(mappings).Times(1)
},
want: mappings,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetBasePathMappingsPages",
&apigateway.GetBasePathMappingsInput{
DomainName: aws.String("domainName1"),
}, mock.AnythingOfType("func(*apigateway.GetBasePathMappingsOutput, bool) bool")).Return(remoteError).Once()
store.On("Get", "apigatewayListAllDomainNameBasePathMappings_domainName_domainName1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllDomainNameBasePathMappings(*domainName.DomainName)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllRestApiModels(t *testing.T) {
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
apiModels := []*apigateway.Model{
{Id: aws.String("model1")},
{Id: aws.String("model2")},
{Id: aws.String("model3")},
{Id: aws.String("model4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.Model
wantErr error
}{
{
name: "list multiple rest api models",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetModelsPages",
&apigateway.GetModelsInput{
RestApiId: aws.String("restapi1"),
},
mock.MatchedBy(func(callback func(res *apigateway.GetModelsOutput, lastPage bool) bool) bool {
callback(&apigateway.GetModelsOutput{
Items: apiModels,
}, true)
return true
})).Return(nil).Once()
store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(nil).Times(1)
store.On("Put", "apigatewayListAllRestApiModels_api_restapi1", apiModels).Return(false).Times(1)
},
want: apiModels,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(apiModels).Times(1)
},
want: apiModels,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetModelsPages",
&apigateway.GetModelsInput{
RestApiId: aws.String("restapi1"),
}, mock.AnythingOfType("func(*apigateway.GetModelsOutput, bool) bool")).Return(remoteError).Once()
store.On("Get", "apigatewayListAllRestApiModels_api_restapi1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApiModels(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayRepository_ListAllRestApiGatewayResponses(t *testing.T) {
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
gtwResponses := []*apigateway.UpdateGatewayResponseOutput{
{ResponseType: aws.String("ACCESS_DENIED")},
{ResponseType: aws.String("DEFAULT_4XX")},
{ResponseType: aws.String("DEFAULT_5XX")},
{ResponseType: aws.String("UNAUTHORIZED")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGateway, store *cache.MockCache)
want []*apigateway.UpdateGatewayResponseOutput
wantErr error
}{
{
name: "list multiple rest api gateway responses",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetGatewayResponses",
&apigateway.GetGatewayResponsesInput{
RestApiId: aws.String("restapi1"),
}).Return(&apigateway.GetGatewayResponsesOutput{Items: gtwResponses}, nil).Once()
store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(nil).Times(1)
store.On("Put", "apigatewayListAllRestApiGatewayResponses_api_restapi1", gtwResponses).Return(false).Times(1)
},
want: gtwResponses,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(gtwResponses).Times(1)
},
want: gtwResponses,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
client.On("GetGatewayResponses",
&apigateway.GetGatewayResponsesInput{
RestApiId: aws.String("restapi1"),
}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayListAllRestApiGatewayResponses_api_restapi1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGateway{}
tt.mocks(client, store)
r := &apigatewayRepository{
client: client,
cache: store,
}
got, err := r.ListAllRestApiGatewayResponses(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}

View File

@ -0,0 +1,228 @@
package repository
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/cache"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/apigatewayv2"
"github.com/aws/aws-sdk-go/service/apigatewayv2/apigatewayv2iface"
)
type ApiGatewayV2Repository interface {
ListAllApis() ([]*apigatewayv2.Api, error)
ListAllApiRoutes(apiId *string) ([]*apigatewayv2.Route, error)
ListAllApiDeployments(apiId *string) ([]*apigatewayv2.Deployment, error)
ListAllVpcLinks() ([]*apigatewayv2.VpcLink, error)
ListAllApiAuthorizers(string) ([]*apigatewayv2.Authorizer, error)
ListAllApiIntegrations(string) ([]*apigatewayv2.Integration, error)
ListAllApiModels(string) ([]*apigatewayv2.Model, error)
ListAllApiStages(string) ([]*apigatewayv2.Stage, error)
ListAllApiRouteResponses(string, string) ([]*apigatewayv2.RouteResponse, error)
ListAllApiMappings(string) ([]*apigatewayv2.ApiMapping, error)
ListAllApiIntegrationResponses(string, string) ([]*apigatewayv2.IntegrationResponse, error)
}
type apigatewayv2Repository struct {
client apigatewayv2iface.ApiGatewayV2API
cache cache.Cache
}
func NewApiGatewayV2Repository(session *session.Session, c cache.Cache) *apigatewayv2Repository {
return &apigatewayv2Repository{
apigatewayv2.New(session),
c,
}
}
func (r *apigatewayv2Repository) ListAllApis() ([]*apigatewayv2.Api, error) {
cacheKey := "apigatewayv2ListAllApis"
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
return v.([]*apigatewayv2.Api), nil
}
input := apigatewayv2.GetApisInput{}
resources, err := r.client.GetApis(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiRoutes(apiID *string) ([]*apigatewayv2.Route, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiRoutes_api_%s", *apiID)
v := r.cache.GetAndLock(cacheKey)
defer r.cache.Unlock(cacheKey)
if v != nil {
return v.([]*apigatewayv2.Route), nil
}
resources, err := r.client.GetRoutes(&apigatewayv2.GetRoutesInput{ApiId: apiID})
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiDeployments(apiID *string) ([]*apigatewayv2.Deployment, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiDeployments_api_%s", *apiID)
v := r.cache.Get(cacheKey)
if v != nil {
return v.([]*apigatewayv2.Deployment), nil
}
resources, err := r.client.GetDeployments(&apigatewayv2.GetDeploymentsInput{ApiId: apiID})
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllVpcLinks() ([]*apigatewayv2.VpcLink, error) {
if v := r.cache.Get("apigatewayv2ListAllVpcLinks"); v != nil {
return v.([]*apigatewayv2.VpcLink), nil
}
input := apigatewayv2.GetVpcLinksInput{}
resources, err := r.client.GetVpcLinks(&input)
if err != nil {
return nil, err
}
r.cache.Put("apigatewayv2ListAllVpcLinks", resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiAuthorizers(apiId string) ([]*apigatewayv2.Authorizer, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiAuthorizers_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigatewayv2.Authorizer), nil
}
input := apigatewayv2.GetAuthorizersInput{
ApiId: &apiId,
}
resources, err := r.client.GetAuthorizers(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiIntegrations(apiId string) ([]*apigatewayv2.Integration, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiIntegrations_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigatewayv2.Integration), nil
}
input := apigatewayv2.GetIntegrationsInput{
ApiId: &apiId,
}
resources, err := r.client.GetIntegrations(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiModels(apiId string) ([]*apigatewayv2.Model, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiModels_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigatewayv2.Model), nil
}
input := apigatewayv2.GetModelsInput{
ApiId: &apiId,
}
resources, err := r.client.GetModels(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiStages(apiId string) ([]*apigatewayv2.Stage, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiStages_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigatewayv2.Stage), nil
}
input := apigatewayv2.GetStagesInput{
ApiId: &apiId,
}
resources, err := r.client.GetStages(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiIntegrationResponses(apiId, integrationId string) ([]*apigatewayv2.IntegrationResponse, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiIntegrationResponses_api_%s_integration_%s", apiId, integrationId)
v := r.cache.Get(cacheKey)
if v != nil {
return v.([]*apigatewayv2.IntegrationResponse), nil
}
input := apigatewayv2.GetIntegrationResponsesInput{
ApiId: &apiId,
IntegrationId: &integrationId,
}
resources, err := r.client.GetIntegrationResponses(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiRouteResponses(apiId, routeId string) ([]*apigatewayv2.RouteResponse, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiRouteResponses_api_%s_route_%s", apiId, routeId)
v := r.cache.Get(cacheKey)
if v != nil {
return v.([]*apigatewayv2.RouteResponse), nil
}
input := apigatewayv2.GetRouteResponsesInput{
ApiId: &apiId,
RouteId: &routeId,
}
resources, err := r.client.GetRouteResponses(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayv2Repository) ListAllApiMappings(domainName string) ([]*apigatewayv2.ApiMapping, error) {
cacheKey := fmt.Sprintf("apigatewayv2ListAllApiMappings_api_%s", domainName)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigatewayv2.ApiMapping), nil
}
input := apigatewayv2.GetApiMappingsInput{
DomainName: &domainName,
}
resources, err := r.client.GetApiMappings(&input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}

View File

@ -0,0 +1,637 @@
package repository
import (
"github.com/snyk/driftctl/enumeration/remote/cache"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/apigatewayv2"
"github.com/pkg/errors"
awstest "github.com/snyk/driftctl/test/aws"
"github.com/r3labs/diff/v2"
"github.com/stretchr/testify/assert"
)
func Test_apigatewayv2Repository_ListAllApis(t *testing.T) {
apis := []*apigatewayv2.Api{
{ApiId: aws.String("api1")},
{ApiId: aws.String("api2")},
{ApiId: aws.String("api3")},
{ApiId: aws.String("api4")},
{ApiId: aws.String("api5")},
{ApiId: aws.String("api6")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.Api
wantErr error
}{
{
name: "list multiple apis",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetApis",
&apigatewayv2.GetApisInput{}).Return(&apigatewayv2.GetApisOutput{Items: apis}, nil).Once()
store.On("GetAndLock", "apigatewayv2ListAllApis").Return(nil).Times(1)
store.On("Unlock", "apigatewayv2ListAllApis").Times(1)
store.On("Put", "apigatewayv2ListAllApis", apis).Return(false).Times(1)
},
want: apis,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("GetAndLock", "apigatewayv2ListAllApis").Return(apis).Times(1)
store.On("Unlock", "apigatewayv2ListAllApis").Times(1)
},
want: apis,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetApis",
&apigatewayv2.GetApisInput{}).Return(nil, remoteError).Once()
store.On("GetAndLock", "apigatewayv2ListAllApis").Return(nil).Times(1)
store.On("Unlock", "apigatewayv2ListAllApis").Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApis()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllApiRoutes(t *testing.T) {
routes := []*apigatewayv2.Route{
{RouteId: aws.String("route1")},
{RouteId: aws.String("route2")},
{RouteId: aws.String("route3")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.Route
wantErr error
}{
{
name: "list multiple routes",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetRoutes",
&apigatewayv2.GetRoutesInput{ApiId: aws.String("an-id")}).
Return(&apigatewayv2.GetRoutesOutput{Items: routes}, nil).Once()
store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(nil).Times(1)
store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1)
store.On("Put", "apigatewayv2ListAllApiRoutes_api_an-id", routes).Return(false).Times(1)
},
want: routes,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(routes).Times(1)
store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1)
},
want: routes,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetRoutes",
&apigatewayv2.GetRoutesInput{ApiId: aws.String("an-id")}).Return(nil, remoteError).Once()
store.On("GetAndLock", "apigatewayv2ListAllApiRoutes_api_an-id").Return(nil).Times(1)
store.On("Unlock", "apigatewayv2ListAllApiRoutes_api_an-id").Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApiRoutes(aws.String("an-id"))
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllApiDeployments(t *testing.T) {
deployments := []*apigatewayv2.Deployment{
{DeploymentId: aws.String("id1")},
{DeploymentId: aws.String("id2")},
{DeploymentId: aws.String("id3")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.Deployment
wantErr error
}{
{
name: "list multiple deployments",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetDeployments",
&apigatewayv2.GetDeploymentsInput{ApiId: aws.String("an-id")}).
Return(&apigatewayv2.GetDeploymentsOutput{Items: deployments}, nil).Once()
store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(nil).Times(1)
store.On("Put", "apigatewayv2ListAllApiDeployments_api_an-id", deployments).Return(false).Times(1)
},
want: deployments,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(deployments).Times(1)
},
want: deployments,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetDeployments",
&apigatewayv2.GetDeploymentsInput{ApiId: aws.String("an-id")}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayv2ListAllApiDeployments_api_an-id").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApiDeployments(aws.String("an-id"))
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllVpcLinks(t *testing.T) {
vpcLinks := []*apigatewayv2.VpcLink{
{VpcLinkId: aws.String("vpcLink1")},
{VpcLinkId: aws.String("vpcLink2")},
{VpcLinkId: aws.String("vpcLink3")},
{VpcLinkId: aws.String("vpcLink4")},
{VpcLinkId: aws.String("vpcLink5")},
{VpcLinkId: aws.String("vpcLink6")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.VpcLink
wantErr error
}{
{
name: "list multiple vpc links",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetVpcLinks",
&apigatewayv2.GetVpcLinksInput{}).Return(&apigatewayv2.GetVpcLinksOutput{Items: vpcLinks}, nil).Once()
store.On("Get", "apigatewayv2ListAllVpcLinks").Return(nil).Times(1)
store.On("Put", "apigatewayv2ListAllVpcLinks", vpcLinks).Return(false).Times(1)
},
want: vpcLinks,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("Get", "apigatewayv2ListAllVpcLinks").Return(vpcLinks).Times(1)
},
want: vpcLinks,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetVpcLinks",
&apigatewayv2.GetVpcLinksInput{}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayv2ListAllVpcLinks").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllVpcLinks()
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllApiAuthorizers(t *testing.T) {
api := &apigatewayv2.Api{
ApiId: aws.String("api1"),
}
apiAuthorizers := []*apigatewayv2.Authorizer{
{AuthorizerId: aws.String("authorizer1")},
{AuthorizerId: aws.String("authorizer2")},
{AuthorizerId: aws.String("authorizer3")},
{AuthorizerId: aws.String("authorizer4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.Authorizer
wantErr error
}{
{
name: "list multiple api authorizers",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetAuthorizers",
&apigatewayv2.GetAuthorizersInput{
ApiId: aws.String("api1"),
}).Return(&apigatewayv2.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once()
store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(nil).Times(1)
store.On("Put", "apigatewayv2ListAllApiAuthorizers_api_api1", apiAuthorizers).Return(false).Times(1)
},
want: apiAuthorizers,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(apiAuthorizers).Times(1)
},
want: apiAuthorizers,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetAuthorizers",
&apigatewayv2.GetAuthorizersInput{
ApiId: aws.String("api1"),
}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayv2ListAllApiAuthorizers_api_api1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApiAuthorizers(*api.ApiId)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllApiIntegrations(t *testing.T) {
api := &apigatewayv2.Api{
ApiId: aws.String("api1"),
}
apiIntegrations := []*apigatewayv2.Integration{
{IntegrationId: aws.String("integration1")},
{IntegrationId: aws.String("integration2")},
{IntegrationId: aws.String("integration3")},
{IntegrationId: aws.String("integration4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.Integration
wantErr error
}{
{
name: "list multiple api integrations",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetIntegrations",
&apigatewayv2.GetIntegrationsInput{
ApiId: aws.String("api1"),
}).Return(&apigatewayv2.GetIntegrationsOutput{Items: apiIntegrations}, nil).Once()
store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(nil).Times(1)
store.On("Put", "apigatewayv2ListAllApiIntegrations_api_api1", apiIntegrations).Return(false).Times(1)
},
want: apiIntegrations,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(apiIntegrations).Times(1)
},
want: apiIntegrations,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetIntegrations",
&apigatewayv2.GetIntegrationsInput{
ApiId: aws.String("api1"),
}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayv2ListAllApiIntegrations_api_api1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApiIntegrations(*api.ApiId)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllApiRouteResponses(t *testing.T) {
api := &apigatewayv2.Api{
ApiId: aws.String("api1"),
}
route := &apigatewayv2.Route{
RouteId: aws.String("route1"),
}
responses := []*apigatewayv2.RouteResponse{
{RouteResponseId: aws.String("response1")},
{RouteResponseId: aws.String("response2")},
{RouteResponseId: aws.String("response3")},
{RouteResponseId: aws.String("response4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.RouteResponse
wantErr error
}{
{
name: "list multiple api route responses",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetRouteResponses",
&apigatewayv2.GetRouteResponsesInput{
ApiId: aws.String("api1"),
RouteId: aws.String("route1"),
}).Return(&apigatewayv2.GetRouteResponsesOutput{Items: responses}, nil).Once()
store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(nil).Times(1)
store.On("Put", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1", responses).Return(false).Times(1)
},
want: responses,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(responses).Times(1)
},
want: responses,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetRouteResponses",
&apigatewayv2.GetRouteResponsesInput{
ApiId: aws.String("api1"),
RouteId: aws.String("route1"),
}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayv2ListAllApiRouteResponses_api_api1_route_route1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApiRouteResponses(*api.ApiId, *route.RouteId)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}
func Test_apigatewayv2Repository_ListAllApiIntegrationResponses(t *testing.T) {
api := &apigatewayv2.Api{
ApiId: aws.String("api1"),
}
integration := &apigatewayv2.Integration{
IntegrationId: aws.String("integration1"),
}
responses := []*apigatewayv2.IntegrationResponse{
{IntegrationResponseId: aws.String("response1")},
{IntegrationResponseId: aws.String("response2")},
{IntegrationResponseId: aws.String("response3")},
{IntegrationResponseId: aws.String("response4")},
}
remoteError := errors.New("remote error")
tests := []struct {
name string
mocks func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache)
want []*apigatewayv2.IntegrationResponse
wantErr error
}{
{
name: "list multiple api integration responses",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetIntegrationResponses",
&apigatewayv2.GetIntegrationResponsesInput{
ApiId: aws.String("api1"),
IntegrationId: aws.String("integration1"),
}).Return(&apigatewayv2.GetIntegrationResponsesOutput{Items: responses}, nil).Once()
store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(nil).Times(1)
store.On("Put", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1", responses).Return(false).Times(1)
},
want: responses,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(responses).Times(1)
},
want: responses,
},
{
name: "should return remote error",
mocks: func(client *awstest.MockFakeApiGatewayV2, store *cache.MockCache) {
client.On("GetIntegrationResponses",
&apigatewayv2.GetIntegrationResponsesInput{
ApiId: aws.String("api1"),
IntegrationId: aws.String("integration1"),
}).Return(nil, remoteError).Once()
store.On("Get", "apigatewayv2ListAllApiIntegrationResponses_api_api1_integration_integration1").Return(nil).Times(1)
},
wantErr: remoteError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := &cache.MockCache{}
client := &awstest.MockFakeApiGatewayV2{}
tt.mocks(client, store)
r := &apigatewayv2Repository{
client: client,
cache: store,
}
got, err := r.ListAllApiIntegrationResponses(*api.ApiId, *integration.IntegrationId)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)
assert.Nil(t, err)
if len(changelog) > 0 {
for _, change := range changelog {
t.Errorf("%s: %s -> %s", strings.Join(change.Path, "."), change.From, change.To)
}
t.Fail()
}
store.AssertExpectations(t)
client.AssertExpectations(t)
})
}
}

View File

@ -0,0 +1,87 @@
package repository
import (
"fmt"
"github.com/snyk/driftctl/enumeration/remote/cache"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/applicationautoscaling"
"github.com/aws/aws-sdk-go/service/applicationautoscaling/applicationautoscalingiface"
)
type AppAutoScalingRepository interface {
ServiceNamespaceValues() []string
DescribeScalableTargets(string) ([]*applicationautoscaling.ScalableTarget, error)
DescribeScalingPolicies(string) ([]*applicationautoscaling.ScalingPolicy, error)
DescribeScheduledActions(string) ([]*applicationautoscaling.ScheduledAction, error)
}
type appAutoScalingRepository struct {
client applicationautoscalingiface.ApplicationAutoScalingAPI
cache cache.Cache
}
func NewAppAutoScalingRepository(session *session.Session, c cache.Cache) *appAutoScalingRepository {
return &appAutoScalingRepository{
applicationautoscaling.New(session),
c,
}
}
func (r *appAutoScalingRepository) ServiceNamespaceValues() []string {
return applicationautoscaling.ServiceNamespace_Values()
}
func (r *appAutoScalingRepository) DescribeScalableTargets(namespace string) ([]*applicationautoscaling.ScalableTarget, error) {
cacheKey := fmt.Sprintf("appAutoScalingDescribeScalableTargets_%s", namespace)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*applicationautoscaling.ScalableTarget), nil
}
input := &applicationautoscaling.DescribeScalableTargetsInput{
ServiceNamespace: &namespace,
}
result, err := r.client.DescribeScalableTargets(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, result.ScalableTargets)
return result.ScalableTargets, nil
}
func (r *appAutoScalingRepository) DescribeScalingPolicies(namespace string) ([]*applicationautoscaling.ScalingPolicy, error) {
cacheKey := fmt.Sprintf("appAutoScalingDescribeScalingPolicies_%s", namespace)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*applicationautoscaling.ScalingPolicy), nil
}
input := &applicationautoscaling.DescribeScalingPoliciesInput{
ServiceNamespace: &namespace,
}
result, err := r.client.DescribeScalingPolicies(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, result.ScalingPolicies)
return result.ScalingPolicies, nil
}
func (r *appAutoScalingRepository) DescribeScheduledActions(namespace string) ([]*applicationautoscaling.ScheduledAction, error) {
cacheKey := fmt.Sprintf("appAutoScalingDescribeScheduledActions_%s", namespace)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*applicationautoscaling.ScheduledAction), nil
}
input := &applicationautoscaling.DescribeScheduledActionsInput{
ServiceNamespace: &namespace,
}
result, err := r.client.DescribeScheduledActions(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, result.ScheduledActions)
return result.ScheduledActions, nil
}

Some files were not shown because too many files have changed in this diff Show More