Fix api_gateway_authorizer list signature + tests

main
William Beuil 2021-10-06 16:43:51 +02:00
parent 7f7c239d5b
commit 9c1e68b226
No known key found for this signature in database
GPG Key ID: BED2072C5C2BF537
5 changed files with 51 additions and 58 deletions

View File

@ -281,7 +281,8 @@ func TestApiGatewayAuthorizer(t *testing.T) {
test: "no api gateway authorizers",
mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) {
repo.On("ListAllRestApis").Return(apis, nil)
repo.On("ListAllRestApiAuthorizers", apis).Return([]*apigateway.Authorizer{}, nil)
repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return([]*apigateway.Authorizer{}, nil).Once()
repo.On("ListAllRestApiAuthorizers", *apis[1].Id).Return([]*apigateway.Authorizer{}, nil).Once()
},
assertExpected: func(t *testing.T, got []*resource.Resource) {
assert.Len(t, got, 0)
@ -291,10 +292,12 @@ func TestApiGatewayAuthorizer(t *testing.T) {
test: "multiple api gateway authorizers",
mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) {
repo.On("ListAllRestApis").Return(apis, nil)
repo.On("ListAllRestApiAuthorizers", apis).Return([]*apigateway.Authorizer{
repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return([]*apigateway.Authorizer{
{Id: awssdk.String("ypcpde")},
}, nil).Once()
repo.On("ListAllRestApiAuthorizers", *apis[1].Id).Return([]*apigateway.Authorizer{
{Id: awssdk.String("bwhebj")},
}, nil)
}, nil).Once()
},
assertExpected: func(t *testing.T, got []*resource.Resource) {
assert.Len(t, got, 2)
@ -318,7 +321,7 @@ func TestApiGatewayAuthorizer(t *testing.T) {
test: "cannot list api gateway resources",
mocks: func(repo *repository.MockApiGatewayRepository, alerter *mocks.AlerterInterface) {
repo.On("ListAllRestApis").Return(apis, nil)
repo.On("ListAllRestApiAuthorizers", apis).Return(nil, dummyError)
repo.On("ListAllRestApiAuthorizers", *apis[0].Id).Return(nil, dummyError)
alerter.On("SendAlert", resourceaws.AwsApiGatewayAuthorizerResourceType, alerts.NewRemoteAccessDeniedAlert(common.RemoteAWSTerraform, remoteerr.NewResourceListingErrorWithType(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType, resourceaws.AwsApiGatewayAuthorizerResourceType), alerts.EnumerationPhase)).Return()
},
wantErr: remoteerr.NewResourceListingError(dummyError, resourceaws.AwsApiGatewayAuthorizerResourceType),

View File

@ -29,22 +29,27 @@ func (e *ApiGatewayAuthorizerEnumerator) Enumerate() ([]*resource.Resource, erro
return nil, remoteerror.NewResourceListingErrorWithType(err, string(e.SupportedType()), aws.AwsApiGatewayRestApiResourceType)
}
authorizers, err := e.repository.ListAllRestApiAuthorizers(apis)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
results := make([]*resource.Resource, 0)
results := make([]*resource.Resource, len(authorizers))
for _, api := range apis {
a := api
authorizers, err := e.repository.ListAllRestApiAuthorizers(*a.Id)
if err != nil {
return nil, remoteerror.NewResourceListingError(err, string(e.SupportedType()))
}
for _, authorizer := range authorizers {
au := authorizer
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*au.Id,
map[string]interface{}{},
),
)
}
for _, authorizer := range authorizers {
results = append(
results,
e.factory.CreateAbstractResource(
string(e.SupportedType()),
*authorizer.Id,
map[string]interface{}{},
),
)
}
return results, err

View File

@ -13,7 +13,7 @@ type ApiGatewayRepository interface {
ListAllRestApis() ([]*apigateway.RestApi, error)
GetAccount() (*apigateway.Account, error)
ListAllApiKeys() ([]*apigateway.ApiKey, error)
ListAllRestApiAuthorizers([]*apigateway.RestApi) ([]*apigateway.Authorizer, error)
ListAllRestApiAuthorizers(string) ([]*apigateway.Authorizer, error)
ListAllRestApiStages(string) ([]*apigateway.Stage, error)
ListAllRestApiResources(string) ([]*apigateway.Resource, error)
}
@ -86,28 +86,22 @@ func (r *apigatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, error) {
return apiKeys, nil
}
func (r *apigatewayRepository) ListAllRestApiAuthorizers(apis []*apigateway.RestApi) ([]*apigateway.Authorizer, error) {
var authorizers []*apigateway.Authorizer
for _, api := range apis {
a := *api
cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", *a.Id)
if v := r.cache.Get(cacheKey); v != nil {
authorizers = append(authorizers, v.([]*apigateway.Authorizer)...)
continue
}
input := &apigateway.GetAuthorizersInput{
RestApiId: a.Id,
}
resources, err := r.client.GetAuthorizers(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
authorizers = append(authorizers, resources.Items...)
func (r *apigatewayRepository) ListAllRestApiAuthorizers(apiId string) ([]*apigateway.Authorizer, error) {
cacheKey := fmt.Sprintf("apigatewayListAllRestApiAuthorizers_api_%s", apiId)
if v := r.cache.Get(cacheKey); v != nil {
return v.([]*apigateway.Authorizer), nil
}
return authorizers, nil
input := &apigateway.GetAuthorizersInput{
RestApiId: &apiId,
}
resources, err := r.client.GetAuthorizers(input)
if err != nil {
return nil, err
}
r.cache.Put(cacheKey, resources.Items)
return resources.Items, nil
}
func (r *apigatewayRepository) ListAllRestApiStages(apiId string) ([]*apigateway.Stage, error) {

View File

@ -211,9 +211,8 @@ func Test_apigatewayRepository_ListAllApiKeys(t *testing.T) {
}
func Test_apigatewayRepository_ListAllRestApiAuthorizers(t *testing.T) {
apis := []*apigateway.RestApi{
{Id: aws.String("restapi1")},
{Id: aws.String("restapi2")},
api := &apigateway.RestApi{
Id: aws.String("restapi1"),
}
apiAuthorizers := []*apigateway.Authorizer{
@ -235,25 +234,17 @@ func Test_apigatewayRepository_ListAllRestApiAuthorizers(t *testing.T) {
client.On("GetAuthorizers",
&apigateway.GetAuthorizersInput{
RestApiId: aws.String("restapi1"),
}).Return(&apigateway.GetAuthorizersOutput{Items: apiAuthorizers[:2]}, nil).Once()
client.On("GetAuthorizers",
&apigateway.GetAuthorizersInput{
RestApiId: aws.String("restapi2"),
}).Return(&apigateway.GetAuthorizersOutput{Items: apiAuthorizers[2:]}, nil).Once()
}).Return(&apigateway.GetAuthorizersOutput{Items: apiAuthorizers}, nil).Once()
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(nil).Times(1)
store.On("Put", "apigatewayListAllRestApiAuthorizers_api_restapi1", apiAuthorizers[:2]).Return(false).Times(1)
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi2").Return(nil).Times(1)
store.On("Put", "apigatewayListAllRestApiAuthorizers_api_restapi2", apiAuthorizers[2:]).Return(false).Times(1)
store.On("Put", "apigatewayListAllRestApiAuthorizers_api_restapi1", apiAuthorizers).Return(false).Times(1)
},
want: apiAuthorizers,
},
{
name: "should hit cache",
mocks: func(client *awstest.MockFakeApiGateway, store *cache.MockCache) {
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(apiAuthorizers[:2]).Times(1)
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi2").Return(apiAuthorizers[2:]).Times(1)
store.On("Get", "apigatewayListAllRestApiAuthorizers_api_restapi1").Return(apiAuthorizers).Times(1)
},
want: apiAuthorizers,
},
@ -267,7 +258,7 @@ func Test_apigatewayRepository_ListAllRestApiAuthorizers(t *testing.T) {
client: client,
cache: store,
}
got, err := r.ListAllRestApiAuthorizers(apis)
got, err := r.ListAllRestApiAuthorizers(*api.Id)
assert.Equal(t, tt.wantErr, err)
changelog, err := diff.Diff(got, tt.want)

View File

@ -59,11 +59,11 @@ func (_m *MockApiGatewayRepository) ListAllApiKeys() ([]*apigateway.ApiKey, erro
}
// ListAllRestApiAuthorizers provides a mock function with given fields: _a0
func (_m *MockApiGatewayRepository) ListAllRestApiAuthorizers(_a0 []*apigateway.RestApi) ([]*apigateway.Authorizer, error) {
func (_m *MockApiGatewayRepository) ListAllRestApiAuthorizers(_a0 string) ([]*apigateway.Authorizer, error) {
ret := _m.Called(_a0)
var r0 []*apigateway.Authorizer
if rf, ok := ret.Get(0).(func([]*apigateway.RestApi) []*apigateway.Authorizer); ok {
if rf, ok := ret.Get(0).(func(string) []*apigateway.Authorizer); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
@ -72,7 +72,7 @@ func (_m *MockApiGatewayRepository) ListAllRestApiAuthorizers(_a0 []*apigateway.
}
var r1 error
if rf, ok := ret.Get(1).(func([]*apigateway.RestApi) error); ok {
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)