Merge pull request #1591 from snyk/fix/account_public_access_block

fix issue when scanning without any s3_account_public_access_block
main
Martin 2022-10-13 11:08:56 +02:00 committed by GitHub
commit a28bac977c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 8 deletions

View File

@ -2,6 +2,7 @@ package repository
import ( import (
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3control" "github.com/aws/aws-sdk-go/service/s3control"
"github.com/snyk/driftctl/enumeration/remote/aws/client" "github.com/snyk/driftctl/enumeration/remote/aws/client"
"github.com/snyk/driftctl/enumeration/remote/cache" "github.com/snyk/driftctl/enumeration/remote/cache"
@ -33,6 +34,10 @@ func (s *s3ControlRepository) DescribeAccountPublicAccessBlock(accountID string)
}) })
if err != nil { if err != nil {
if s.shouldSuppressError(err) {
return nil, nil
}
return nil, err return nil, err
} }
@ -41,3 +46,13 @@ func (s *s3ControlRepository) DescribeAccountPublicAccessBlock(accountID string)
s.cache.Put(cacheKey, result) s.cache.Put(cacheKey, result)
return result, nil return result, nil
} }
func (s *s3ControlRepository) shouldSuppressError(err error) bool {
if requestFailure, ok := err.(awserr.RequestFailure); ok {
if requestFailure.Code() == "NoSuchPublicAccessBlockConfiguration" {
// do not throw the error up if there is no access block config
return true
}
}
return false
}

View File

@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/r3labs/diff/v2" "github.com/r3labs/diff/v2"
awstest "github.com/snyk/driftctl/test/aws" awstest "github.com/snyk/driftctl/test/aws"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -23,7 +22,7 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
name string name string
mocks func(client *awstest.MockFakeS3Control) mocks func(client *awstest.MockFakeS3Control)
want *s3control.PublicAccessBlockConfiguration want *s3control.PublicAccessBlockConfiguration
wantErr error wantErr bool
}{ }{
{ {
name: "describe account public access block", name: "describe account public access block",
@ -48,15 +47,29 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
}, },
}, },
{ {
name: "Error detting account public accessblock", name: "Error getting account public access block",
mocks: func(client *awstest.MockFakeS3Control) { mocks: func(client *awstest.MockFakeS3Control) {
fakeRequestFailure := &awstest.MockFakeRequestFailure{}
fakeRequestFailure.On("Code").Return("FakeErrorCode")
client.On("GetPublicAccessBlock", mock.Anything).Return( client.On("GetPublicAccessBlock", mock.Anything).Return(
nil, nil,
awserr.NewRequestFailure(nil, 403, ""), fakeRequestFailure,
).Once()
},
want: nil,
wantErr: true,
},
{
name: "Error no account public access block",
mocks: func(client *awstest.MockFakeS3Control) {
fakeRequestFailure := &awstest.MockFakeRequestFailure{}
fakeRequestFailure.On("Code").Return("NoSuchPublicAccessBlockConfiguration")
client.On("GetPublicAccessBlock", mock.Anything).Return(
nil,
fakeRequestFailure,
).Once() ).Once()
}, },
want: nil, want: nil,
wantErr: awserr.NewRequestFailure(nil, 403, ""),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -69,9 +82,9 @@ func Test_s3ControlRepository_DescribeAccountPublicAccessBlock(t *testing.T) {
r := NewS3ControlRepository(&factory, store) r := NewS3ControlRepository(&factory, store)
got, err := r.DescribeAccountPublicAccessBlock(accountID) got, err := r.DescribeAccountPublicAccessBlock(accountID)
factory.AssertExpectations(t) factory.AssertExpectations(t)
assert.Equal(t, tt.wantErr, err) assert.Equal(t, tt.wantErr, err != nil)
if err == nil { if err == nil && got != nil {
// Check that results were cached // Check that results were cached
cachedData, err := r.DescribeAccountPublicAccessBlock(accountID) cachedData, err := r.DescribeAccountPublicAccessBlock(accountID)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -37,6 +37,10 @@ func (e *S3AccountPublicAccessBlockEnumerator) Enumerate() ([]*resource.Resource
results := make([]*resource.Resource, 0, 1) results := make([]*resource.Resource, 0, 1)
if accountPublicAccessBlock == nil {
return results, nil
}
results = append( results = append(
results, results,
e.factory.CreateAbstractResource( e.factory.CreateAbstractResource(