fix: inject account id to enumerator instead of repo

main
Martin Guibert 2022-10-05 16:03:06 +02:00
parent e0104c848b
commit c94dad7f16
No known key found for this signature in database
GPG Key ID: 990E40316943BAA6
7 changed files with 35 additions and 54 deletions

View File

@ -35,7 +35,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
repositoryCache := cache.New(100) repositoryCache := cache.New(100)
s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache) s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), provider.accountId, repositoryCache) s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), repositoryCache)
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache) ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache) elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache)
route53repository := repository.NewRoute53Repository(provider.session, repositoryCache) route53repository := repository.NewRoute53Repository(provider.session, repositoryCache)
@ -72,7 +72,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter)) remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer)) remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter)) remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.Config, alerter)) remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.accountId, alerter))
remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory)) remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer)) remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))

View File

@ -12,13 +12,13 @@ type MockS3ControlRepository struct {
mock.Mock mock.Mock
} }
// DescribeAccountPublicAccessBlock provides a mock function with given fields: // DescribeAccountPublicAccessBlock provides a mock function with given fields: accountID
func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) { func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) {
ret := _m.Called() ret := _m.Called(accountID)
var r0 *s3control.PublicAccessBlockConfiguration var r0 *s3control.PublicAccessBlockConfiguration
if rf, ok := ret.Get(0).(func() *s3control.PublicAccessBlockConfiguration); ok { if rf, ok := ret.Get(0).(func(string) *s3control.PublicAccessBlockConfiguration); ok {
r0 = rf() r0 = rf(accountID)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*s3control.PublicAccessBlockConfiguration) r0 = ret.Get(0).(*s3control.PublicAccessBlockConfiguration)
@ -26,8 +26,8 @@ func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3contro
} }
var r1 error var r1 error
if rf, ok := ret.Get(1).(func() error); ok { if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf() r1 = rf(accountID)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -35,20 +35,6 @@ func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3contro
return r0, r1 return r0, r1
} }
// GetAccountID provides a mock function with given fields:
func (_m *MockS3ControlRepository) GetAccountID() string {
ret := _m.Called()
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
type mockConstructorTestingTNewMockS3ControlRepository interface { type mockConstructorTestingTNewMockS3ControlRepository interface {
mock.TestingT mock.TestingT
Cleanup(func()) Cleanup(func())

View File

@ -8,34 +8,28 @@ import (
) )
type S3ControlRepository interface { type S3ControlRepository interface {
DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error)
GetAccountID() string
} }
type s3ControlRepository struct { type s3ControlRepository struct {
clientFactory client.AwsClientFactoryInterface clientFactory client.AwsClientFactoryInterface
accountId string
cache cache.Cache cache cache.Cache
} }
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, accountId string, c cache.Cache) *s3ControlRepository { func NewS3ControlRepository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3ControlRepository {
return &s3ControlRepository{ return &s3ControlRepository{
clientFactory: factory, clientFactory: factory,
accountId: accountId,
cache: c, cache: c,
} }
} }
func (s *s3ControlRepository) GetAccountID() string {
return s.accountId
}
func (s *s3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) { func (s *s3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) {
cacheKey := "S3DescribeAccountPublicAccessBlock" cacheKey := "S3DescribeAccountPublicAccessBlock"
if v := s.cache.Get(cacheKey); v != nil { if v := s.cache.Get(cacheKey); v != nil {
return v.(*s3control.PublicAccessBlockConfiguration), nil return v.(*s3control.PublicAccessBlockConfiguration), nil
} }
out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{ out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{
AccountId: aws.String(s.accountId), AccountId: aws.String(accountID),
}) })
if err != nil { if err != nil {

View File

@ -17,6 +17,7 @@ import (
) )
func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) { func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
accountID := "123456"
tests := []struct { tests := []struct {
name string name string
@ -65,14 +66,14 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
tt.mocks(mockedClient) tt.mocks(mockedClient)
factory := client.MockAwsClientFactoryInterface{} factory := client.MockAwsClientFactoryInterface{}
factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once() factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once()
r := NewS3ControlRepository(&factory, "", store) r := NewS3ControlRepository(&factory, store)
got, err := r.DescribeAccountPublicAccessBlock() got, err := r.DescribeAccountPublicAccessBlock(accountID)
factory.AssertExpectations(t) factory.AssertExpectations(t)
assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.wantErr, err)
if err == nil { if err == nil {
// Check that results were cached // Check that results were cached
cachedData, err := r.DescribeAccountPublicAccessBlock() cachedData, err := r.DescribeAccountPublicAccessBlock(accountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, got, cachedData) assert.Equal(t, got, cachedData)
assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock")) assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock"))

View File

@ -5,7 +5,6 @@ import (
"github.com/snyk/driftctl/enumeration/alerter" "github.com/snyk/driftctl/enumeration/alerter"
"github.com/snyk/driftctl/enumeration/remote/aws/repository" "github.com/snyk/driftctl/enumeration/remote/aws/repository"
remoteerror "github.com/snyk/driftctl/enumeration/remote/error" remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
tf "github.com/snyk/driftctl/enumeration/remote/terraform"
"github.com/snyk/driftctl/enumeration/resource" "github.com/snyk/driftctl/enumeration/resource"
"github.com/snyk/driftctl/enumeration/resource/aws" "github.com/snyk/driftctl/enumeration/resource/aws"
) )
@ -13,15 +12,15 @@ import (
type S3AccountPublicAccessBlockEnumerator struct { type S3AccountPublicAccessBlockEnumerator struct {
repository repository.S3ControlRepository repository repository.S3ControlRepository
factory resource.ResourceFactory factory resource.ResourceFactory
providerConfig tf.TerraformProviderConfig accountID string
alerter alerter.AlerterInterface alerter alerter.AlerterInterface
} }
func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator { func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, accountId string, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
return &S3AccountPublicAccessBlockEnumerator{ return &S3AccountPublicAccessBlockEnumerator{
repository: repo, repository: repo,
factory: factory, factory: factory,
providerConfig: providerConfig, accountID: accountId,
alerter: alerter, alerter: alerter,
} }
} }
@ -31,7 +30,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) SupportedType() resource.Resource
} }
func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) { func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) {
accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock() accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock(e.accountID)
if err != nil { if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType())) return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
} }
@ -42,7 +41,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource
results, results,
e.factory.CreateAbstractResource( e.factory.CreateAbstractResource(
string(e.SupportedType()), string(e.SupportedType()),
e.repository.GetAccountID(), e.accountID,
map[string]interface{}{ map[string]interface{}{
"block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls), "block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls),
"block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy), "block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy),

View File

@ -1071,6 +1071,7 @@ func TestS3BucketAnalytic(t *testing.T) {
func TestS3AccountPublicAccessBlock(t *testing.T) { func TestS3AccountPublicAccessBlock(t *testing.T) {
dummyError := errors.New("this is an error") dummyError := errors.New("this is an error")
accountID := "123456"
tests := []struct { tests := []struct {
test string test string
mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface) mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface)
@ -1080,8 +1081,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
{ {
test: "existing access block", test: "existing access block",
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) { mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
repository.On("GetAccountID").Return("123456") repository.On("DescribeAccountPublicAccessBlock", accountID).Return(&s3control.PublicAccessBlockConfiguration{
repository.On("DescribeAccountPublicAccessBlock").Return(&s3control.PublicAccessBlockConfiguration{
BlockPublicAcls: awssdk.Bool(false), BlockPublicAcls: awssdk.Bool(false),
BlockPublicPolicy: awssdk.Bool(true), BlockPublicPolicy: awssdk.Bool(true),
IgnorePublicAcls: awssdk.Bool(false), IgnorePublicAcls: awssdk.Bool(false),
@ -1090,7 +1090,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
}, },
assertExpected: func(t *testing.T, got []*resource.Resource) { assertExpected: func(t *testing.T, got []*resource.Resource) {
assert.Len(t, got, 1) assert.Len(t, got, 1)
assert.Equal(t, got[0].ResourceId(), "123456") assert.Equal(t, got[0].ResourceId(), accountID)
assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock) assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock)
assert.Equal(t, got[0].Attributes(), &resource.Attributes{ assert.Equal(t, got[0].Attributes(), &resource.Attributes{
"block_public_acls": false, "block_public_acls": false,
@ -1103,7 +1103,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
{ {
test: "cannot list access block", test: "cannot list access block",
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) { mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
repository.On("DescribeAccountPublicAccessBlock").Return(nil, dummyError) repository.On("DescribeAccountPublicAccessBlock", accountID).Return(nil, dummyError)
}, },
wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsS3AccountPublicAccessBlock), wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsS3AccountPublicAccessBlock),
}, },
@ -1125,7 +1125,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator( remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator(
repo, factory, repo, factory,
tf.TerraformProviderConfig{DefaultAlias: "us-east-1"}, accountID,
alerter, alerter,
)) ))

View File

@ -99,7 +99,8 @@ var supportedTypes = map[string]ResourceTypeMeta{
"aws_s3_bucket_metric": {}, "aws_s3_bucket_metric": {},
"aws_s3_bucket_notification": {}, "aws_s3_bucket_notification": {},
"aws_s3_bucket_policy": {}, "aws_s3_bucket_policy": {},
"aws_s3_bucket_public_access_block": {}, "aws_security_group": {children: []ResourceType{ "aws_s3_bucket_public_access_block": {},
"aws_security_group": {children: []ResourceType{
"aws_security_group_rule", "aws_security_group_rule",
}}, }},
"aws_s3_account_public_access_block": {}, "aws_s3_account_public_access_block": {},