fix: inject account id to enumerator instead of repo
parent
e0104c848b
commit
c94dad7f16
|
@ -35,7 +35,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
|
||||||
repositoryCache := cache.New(100)
|
repositoryCache := cache.New(100)
|
||||||
|
|
||||||
s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
|
s3Repository := repository.NewS3Repository(client.NewAWSClientFactory(provider.session), repositoryCache)
|
||||||
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), provider.accountId, repositoryCache)
|
s3ControlRepository := repository.NewS3ControlRepository(client.NewAWSClientFactory(provider.session), repositoryCache)
|
||||||
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
|
ec2repository := repository.NewEC2Repository(provider.session, repositoryCache)
|
||||||
elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache)
|
elbv2Repository := repository.NewELBV2Repository(provider.session, repositoryCache)
|
||||||
route53repository := repository.NewRoute53Repository(provider.session, repositoryCache)
|
route53repository := repository.NewRoute53Repository(provider.session, repositoryCache)
|
||||||
|
@ -72,7 +72,7 @@ func Init(version string, alerter alerter.AlerterInterface, providerLibrary *ter
|
||||||
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
|
remoteLibrary.AddEnumerator(NewS3BucketAnalyticEnumerator(s3Repository, factory, provider.Config, alerter))
|
||||||
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
|
remoteLibrary.AddDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, common.NewGenericDetailsFetcher(aws.AwsS3BucketAnalyticsConfigurationResourceType, provider, deserializer))
|
||||||
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
|
remoteLibrary.AddEnumerator(NewS3BucketPublicAccessBlockEnumerator(s3Repository, factory, provider.Config, alerter))
|
||||||
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.Config, alerter))
|
remoteLibrary.AddEnumerator(NewS3AccountPublicAccessBlockEnumerator(s3ControlRepository, factory, provider.accountId, alerter))
|
||||||
|
|
||||||
remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
|
remoteLibrary.AddEnumerator(NewEC2EbsVolumeEnumerator(ec2repository, factory))
|
||||||
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))
|
remoteLibrary.AddDetailsFetcher(aws.AwsEbsVolumeResourceType, common.NewGenericDetailsFetcher(aws.AwsEbsVolumeResourceType, provider, deserializer))
|
||||||
|
|
|
@ -12,13 +12,13 @@ type MockS3ControlRepository struct {
|
||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// DescribeAccountPublicAccessBlock provides a mock function with given fields:
|
// DescribeAccountPublicAccessBlock provides a mock function with given fields: accountID
|
||||||
func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) {
|
func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) {
|
||||||
ret := _m.Called()
|
ret := _m.Called(accountID)
|
||||||
|
|
||||||
var r0 *s3control.PublicAccessBlockConfiguration
|
var r0 *s3control.PublicAccessBlockConfiguration
|
||||||
if rf, ok := ret.Get(0).(func() *s3control.PublicAccessBlockConfiguration); ok {
|
if rf, ok := ret.Get(0).(func(string) *s3control.PublicAccessBlockConfiguration); ok {
|
||||||
r0 = rf()
|
r0 = rf(accountID)
|
||||||
} else {
|
} else {
|
||||||
if ret.Get(0) != nil {
|
if ret.Get(0) != nil {
|
||||||
r0 = ret.Get(0).(*s3control.PublicAccessBlockConfiguration)
|
r0 = ret.Get(0).(*s3control.PublicAccessBlockConfiguration)
|
||||||
|
@ -26,8 +26,8 @@ func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3contro
|
||||||
}
|
}
|
||||||
|
|
||||||
var r1 error
|
var r1 error
|
||||||
if rf, ok := ret.Get(1).(func() error); ok {
|
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||||
r1 = rf()
|
r1 = rf(accountID)
|
||||||
} else {
|
} else {
|
||||||
r1 = ret.Error(1)
|
r1 = ret.Error(1)
|
||||||
}
|
}
|
||||||
|
@ -35,20 +35,6 @@ func (_m *MockS3ControlRepository) DescribeAccountPublicAccessBlock() (*s3contro
|
||||||
return r0, r1
|
return r0, r1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountID provides a mock function with given fields:
|
|
||||||
func (_m *MockS3ControlRepository) GetAccountID() string {
|
|
||||||
ret := _m.Called()
|
|
||||||
|
|
||||||
var r0 string
|
|
||||||
if rf, ok := ret.Get(0).(func() string); ok {
|
|
||||||
r0 = rf()
|
|
||||||
} else {
|
|
||||||
r0 = ret.Get(0).(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockConstructorTestingTNewMockS3ControlRepository interface {
|
type mockConstructorTestingTNewMockS3ControlRepository interface {
|
||||||
mock.TestingT
|
mock.TestingT
|
||||||
Cleanup(func())
|
Cleanup(func())
|
||||||
|
|
|
@ -8,34 +8,28 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type S3ControlRepository interface {
|
type S3ControlRepository interface {
|
||||||
DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error)
|
DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error)
|
||||||
GetAccountID() string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type s3ControlRepository struct {
|
type s3ControlRepository struct {
|
||||||
clientFactory client.AwsClientFactoryInterface
|
clientFactory client.AwsClientFactoryInterface
|
||||||
accountId string
|
|
||||||
cache cache.Cache
|
cache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, accountId string, c cache.Cache) *s3ControlRepository {
|
func NewS3ControlRepository(factory client.AwsClientFactoryInterface, c cache.Cache) *s3ControlRepository {
|
||||||
return &s3ControlRepository{
|
return &s3ControlRepository{
|
||||||
clientFactory: factory,
|
clientFactory: factory,
|
||||||
accountId: accountId,
|
|
||||||
cache: c,
|
cache: c,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (s *s3ControlRepository) GetAccountID() string {
|
|
||||||
return s.accountId
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *s3ControlRepository) DescribeAccountPublicAccessBlock() (*s3control.PublicAccessBlockConfiguration, error) {
|
func (s *s3ControlRepository) DescribeAccountPublicAccessBlock(accountID string) (*s3control.PublicAccessBlockConfiguration, error) {
|
||||||
cacheKey := "S3DescribeAccountPublicAccessBlock"
|
cacheKey := "S3DescribeAccountPublicAccessBlock"
|
||||||
if v := s.cache.Get(cacheKey); v != nil {
|
if v := s.cache.Get(cacheKey); v != nil {
|
||||||
return v.(*s3control.PublicAccessBlockConfiguration), nil
|
return v.(*s3control.PublicAccessBlockConfiguration), nil
|
||||||
}
|
}
|
||||||
out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{
|
out, err := s.clientFactory.GetS3ControlClient(nil).GetPublicAccessBlock(&s3control.GetPublicAccessBlockInput{
|
||||||
AccountId: aws.String(s.accountId),
|
AccountId: aws.String(accountID),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
|
func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
|
||||||
|
accountID := "123456"
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -65,14 +66,14 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
|
||||||
tt.mocks(mockedClient)
|
tt.mocks(mockedClient)
|
||||||
factory := client.MockAwsClientFactoryInterface{}
|
factory := client.MockAwsClientFactoryInterface{}
|
||||||
factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once()
|
factory.On("GetS3ControlClient", (*aws.Config)(nil)).Return(mockedClient).Once()
|
||||||
r := NewS3ControlRepository(&factory, "", store)
|
r := NewS3ControlRepository(&factory, store)
|
||||||
got, err := r.DescribeAccountPublicAccessBlock()
|
got, err := r.DescribeAccountPublicAccessBlock(accountID)
|
||||||
factory.AssertExpectations(t)
|
factory.AssertExpectations(t)
|
||||||
assert.Equal(t, tt.wantErr, err)
|
assert.Equal(t, tt.wantErr, err)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Check that results were cached
|
// Check that results were cached
|
||||||
cachedData, err := r.DescribeAccountPublicAccessBlock()
|
cachedData, err := r.DescribeAccountPublicAccessBlock(accountID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, got, cachedData)
|
assert.Equal(t, got, cachedData)
|
||||||
assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock"))
|
assert.IsType(t, &s3control.PublicAccessBlockConfiguration{}, store.Get("S3DescribeAccountPublicAccessBlock"))
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"github.com/snyk/driftctl/enumeration/alerter"
|
"github.com/snyk/driftctl/enumeration/alerter"
|
||||||
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
|
"github.com/snyk/driftctl/enumeration/remote/aws/repository"
|
||||||
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
|
remoteerror "github.com/snyk/driftctl/enumeration/remote/error"
|
||||||
tf "github.com/snyk/driftctl/enumeration/remote/terraform"
|
|
||||||
"github.com/snyk/driftctl/enumeration/resource"
|
"github.com/snyk/driftctl/enumeration/resource"
|
||||||
"github.com/snyk/driftctl/enumeration/resource/aws"
|
"github.com/snyk/driftctl/enumeration/resource/aws"
|
||||||
)
|
)
|
||||||
|
@ -13,15 +12,15 @@ import (
|
||||||
type S3AccountPublicAccessBlockEnumerator struct {
|
type S3AccountPublicAccessBlockEnumerator struct {
|
||||||
repository repository.S3ControlRepository
|
repository repository.S3ControlRepository
|
||||||
factory resource.ResourceFactory
|
factory resource.ResourceFactory
|
||||||
providerConfig tf.TerraformProviderConfig
|
accountID string
|
||||||
alerter alerter.AlerterInterface
|
alerter alerter.AlerterInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, providerConfig tf.TerraformProviderConfig, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
|
func NewS3AccountPublicAccessBlockEnumerator(repo repository.S3ControlRepository, factory resource.ResourceFactory, accountId string, alerter alerter.AlerterInterface) *S3AccountPublicAccessBlockEnumerator {
|
||||||
return &S3AccountPublicAccessBlockEnumerator{
|
return &S3AccountPublicAccessBlockEnumerator{
|
||||||
repository: repo,
|
repository: repo,
|
||||||
factory: factory,
|
factory: factory,
|
||||||
providerConfig: providerConfig,
|
accountID: accountId,
|
||||||
alerter: alerter,
|
alerter: alerter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -31,7 +30,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) SupportedType() resource.Resource
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) {
|
func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource, error) {
|
||||||
accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock()
|
accountPublicAccessBlock, err := e.repository.DescribeAccountPublicAccessBlock(e.accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
|
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
|
||||||
}
|
}
|
||||||
|
@ -42,7 +41,7 @@ func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource
|
||||||
results,
|
results,
|
||||||
e.factory.CreateAbstractResource(
|
e.factory.CreateAbstractResource(
|
||||||
string(e.SupportedType()),
|
string(e.SupportedType()),
|
||||||
e.repository.GetAccountID(),
|
e.accountID,
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls),
|
"block_public_acls": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicAcls),
|
||||||
"block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy),
|
"block_public_policy": awssdk.BoolValue(accountPublicAccessBlock.BlockPublicPolicy),
|
||||||
|
|
|
@ -1071,6 +1071,7 @@ func TestS3BucketAnalytic(t *testing.T) {
|
||||||
func TestS3AccountPublicAccessBlock(t *testing.T) {
|
func TestS3AccountPublicAccessBlock(t *testing.T) {
|
||||||
dummyError := errors.New("this is an error")
|
dummyError := errors.New("this is an error")
|
||||||
|
|
||||||
|
accountID := "123456"
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
test string
|
test string
|
||||||
mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface)
|
mocks func(*repository.MockS3ControlRepository, *mocks.AlerterInterface)
|
||||||
|
@ -1080,8 +1081,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
|
||||||
{
|
{
|
||||||
test: "existing access block",
|
test: "existing access block",
|
||||||
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
|
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
|
||||||
repository.On("GetAccountID").Return("123456")
|
repository.On("DescribeAccountPublicAccessBlock", accountID).Return(&s3control.PublicAccessBlockConfiguration{
|
||||||
repository.On("DescribeAccountPublicAccessBlock").Return(&s3control.PublicAccessBlockConfiguration{
|
|
||||||
BlockPublicAcls: awssdk.Bool(false),
|
BlockPublicAcls: awssdk.Bool(false),
|
||||||
BlockPublicPolicy: awssdk.Bool(true),
|
BlockPublicPolicy: awssdk.Bool(true),
|
||||||
IgnorePublicAcls: awssdk.Bool(false),
|
IgnorePublicAcls: awssdk.Bool(false),
|
||||||
|
@ -1090,7 +1090,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
|
||||||
},
|
},
|
||||||
assertExpected: func(t *testing.T, got []*resource.Resource) {
|
assertExpected: func(t *testing.T, got []*resource.Resource) {
|
||||||
assert.Len(t, got, 1)
|
assert.Len(t, got, 1)
|
||||||
assert.Equal(t, got[0].ResourceId(), "123456")
|
assert.Equal(t, got[0].ResourceId(), accountID)
|
||||||
assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock)
|
assert.Equal(t, got[0].ResourceType(), resourceaws.AwsS3AccountPublicAccessBlock)
|
||||||
assert.Equal(t, got[0].Attributes(), &resource.Attributes{
|
assert.Equal(t, got[0].Attributes(), &resource.Attributes{
|
||||||
"block_public_acls": false,
|
"block_public_acls": false,
|
||||||
|
@ -1103,7 +1103,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
|
||||||
{
|
{
|
||||||
test: "cannot list access block",
|
test: "cannot list access block",
|
||||||
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
|
mocks: func(repository *repository.MockS3ControlRepository, alerter *mocks.AlerterInterface) {
|
||||||
repository.On("DescribeAccountPublicAccessBlock").Return(nil, dummyError)
|
repository.On("DescribeAccountPublicAccessBlock", accountID).Return(nil, dummyError)
|
||||||
},
|
},
|
||||||
wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsS3AccountPublicAccessBlock),
|
wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsS3AccountPublicAccessBlock),
|
||||||
},
|
},
|
||||||
|
@ -1125,7 +1125,7 @@ func TestS3AccountPublicAccessBlock(t *testing.T) {
|
||||||
|
|
||||||
remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator(
|
remoteLibrary.AddEnumerator(aws.NewS3AccountPublicAccessBlockEnumerator(
|
||||||
repo, factory,
|
repo, factory,
|
||||||
tf.TerraformProviderConfig{DefaultAlias: "us-east-1"},
|
accountID,
|
||||||
alerter,
|
alerter,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,8 @@ var supportedTypes = map[string]ResourceTypeMeta{
|
||||||
"aws_s3_bucket_metric": {},
|
"aws_s3_bucket_metric": {},
|
||||||
"aws_s3_bucket_notification": {},
|
"aws_s3_bucket_notification": {},
|
||||||
"aws_s3_bucket_policy": {},
|
"aws_s3_bucket_policy": {},
|
||||||
"aws_s3_bucket_public_access_block": {}, "aws_security_group": {children: []ResourceType{
|
"aws_s3_bucket_public_access_block": {},
|
||||||
|
"aws_security_group": {children: []ResourceType{
|
||||||
"aws_security_group_rule",
|
"aws_security_group_rule",
|
||||||
}},
|
}},
|
||||||
"aws_s3_account_public_access_block": {},
|
"aws_s3_account_public_access_block": {},
|
||||||
|
|
Loading…
Reference in New Issue