driftctl/pkg/remote/aws/repository/kms_repository.go

147 lines
3.8 KiB
Go

package repository
import (
"fmt"
"strings"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/cloudskiff/driftctl/pkg/remote/cache"
"github.com/sirupsen/logrus"
)
// KMSRepository lists the AWS KMS resources that are relevant for drift
// detection.
type KMSRepository interface {
// ListAllKeys returns the KMS key list entries, or an error if they could
// not be retrieved. The implementation in this file drops AWS-managed keys
// and keys that are pending deletion.
ListAllKeys() ([]*kms.KeyListEntry, error)
// ListAllAliases returns the KMS alias list entries, or an error if they
// could not be retrieved. The implementation in this file drops aliases
// prefixed with "alias/aws/" and aliases whose target key is pending deletion.
ListAllAliases() ([]*kms.AliasListEntry, error)
}
// kmsRepository is the KMSRepository implementation backed by the AWS SDK KMS
// client, with results memoized in a cache.
type kmsRepository struct {
// client is the KMS API used for the ListKeys, ListAliases and DescribeKey calls.
client kmsiface.KMSAPI
// cache stores the filtered listings and the per-key DescribeKey outputs.
cache cache.Cache
// describeKeyLock is held for the whole of describeKey, so the cache lookup
// and the API call for a key are not interleaved between concurrent callers.
describeKeyLock *sync.Mutex
}
// NewKMSRepository builds a kmsRepository whose KMS client is created from the
// given AWS session and whose results are stored in c. A fresh mutex is
// allocated to guard describeKey.
func NewKMSRepository(session *session.Session, c cache.Cache) *kmsRepository {
	repo := &kmsRepository{
		client:          kms.New(session),
		cache:           c,
		describeKeyLock: &sync.Mutex{},
	}
	return repo
}
// ListAllKeys returns the KMS keys kept by filterKeys. The filtered result is
// stored in the cache under "kmsListAllKeys" and served from there on later
// calls. Any error from paging through ListKeys or from filtering is returned
// as-is and nothing is cached.
func (r *kmsRepository) ListAllKeys() ([]*kms.KeyListEntry, error) {
	const cacheKey = "kmsListAllKeys"

	if cached := r.cache.Get(cacheKey); cached != nil {
		return cached.([]*kms.KeyListEntry), nil
	}

	// Accumulate every page; keep paging until the SDK reports the last one.
	var listed []*kms.KeyListEntry
	collect := func(page *kms.ListKeysOutput, lastPage bool) bool {
		listed = append(listed, page.Keys...)
		return !lastPage
	}
	if err := r.client.ListKeysPages(&kms.ListKeysInput{}, collect); err != nil {
		return nil, err
	}

	filtered, err := r.filterKeys(listed)
	if err != nil {
		return nil, err
	}

	r.cache.Put(cacheKey, filtered)
	return filtered, nil
}
// ListAllAliases returns the KMS aliases kept by filterAliases. The filtered
// result is stored in the cache under "kmsListAllAliases" and served from
// there on later calls. Any error from paging through ListAliases or from
// filtering is returned as-is and nothing is cached.
func (r *kmsRepository) ListAllAliases() ([]*kms.AliasListEntry, error) {
	const cacheKey = "kmsListAllAliases"

	if cached := r.cache.Get(cacheKey); cached != nil {
		return cached.([]*kms.AliasListEntry), nil
	}

	// Accumulate every page; keep paging until the SDK reports the last one.
	var listed []*kms.AliasListEntry
	collect := func(page *kms.ListAliasesOutput, lastPage bool) bool {
		listed = append(listed, page.Aliases...)
		return !lastPage
	}
	if err := r.client.ListAliasesPages(&kms.ListAliasesInput{}, collect); err != nil {
		return nil, err
	}

	filtered, err := r.filterAliases(listed)
	if err != nil {
		return nil, err
	}

	r.cache.Put(cacheKey, filtered)
	return filtered, nil
}
// describeKey returns the DescribeKey output for keyId, using the cache entry
// "kmsDescribeKey-<id>" when present and populating it otherwise.
//
// It returns (nil, nil) when the key is in the PendingDeletion state, so
// callers must check the result for nil before using it. A non-nil result
// always has a non-nil KeyMetadata.
//
// An error is returned when the DescribeKey call fails, or when the cached or
// fetched value is not a DescribeKey output carrying key metadata.
func (r *kmsRepository) describeKey(keyId *string) (*kms.DescribeKeyOutput, error) {
	// Since this method can be called in parallel, we lock around the whole
	// lookup so that a concurrent caller for the same key hits the cache
	// instead of issuing a second API call.
	r.describeKeyLock.Lock()
	defer r.describeKeyLock.Unlock()

	// aws.StringValue tolerates a nil keyId; the request below is then sent
	// with a nil KeyId and any resulting error is returned to the caller
	// instead of panicking on a nil dereference here.
	id := aws.StringValue(keyId)
	cacheKey := fmt.Sprintf("kmsDescribeKey-%s", id)

	cached := r.cache.Get(cacheKey)
	if cached == nil {
		out, err := r.client.DescribeKey(&kms.DescribeKeyInput{KeyId: keyId})
		if err != nil {
			return nil, err
		}
		r.cache.Put(cacheKey, out)
		cached = out
	}

	// Guard every dereference: a foreign cache value, a nil output or a
	// response without metadata is reported as an error rather than a panic.
	output, ok := cached.(*kms.DescribeKeyOutput)
	if !ok || output == nil || output.KeyMetadata == nil {
		return nil, fmt.Errorf("describing kms key %q: response has no key metadata", id)
	}

	if aws.StringValue(output.KeyMetadata.KeyState) == kms.KeyStatePendingDeletion {
		return nil, nil
	}
	return output, nil
}
// filterKeys returns the entries of keys that should be reported, preserving
// their order. An entry is dropped when describeKey reports its key as pending
// deletion (logged at debug level), when its KeyManager is unset, or when its
// KeyManager is "AWS". The first describeKey error aborts the filtering and is
// returned with a nil slice.
func (r *kmsRepository) filterKeys(keys []*kms.KeyListEntry) ([]*kms.KeyListEntry, error) {
	var kept []*kms.KeyListEntry

	for i := range keys {
		entry := keys[i]

		described, err := r.describeKey(entry.KeyId)
		if err != nil {
			return nil, err
		}

		// describeKey signals a key pending deletion with a nil output.
		if described == nil {
			logrus.WithFields(logrus.Fields{
				"id": *entry.KeyId,
			}).Debug("Ignored kms key from listing since it is pending from deletion")
			continue
		}

		manager := described.KeyMetadata.KeyManager
		if manager == nil || *manager == "AWS" {
			continue
		}
		kept = append(kept, entry)
	}

	return kept, nil
}
// filterAliases returns the entries of aliases that should be reported,
// preserving their order. An entry is dropped when it has no name, when its
// name starts with "alias/aws/", or when describeKey reports its target key as
// pending deletion (logged at debug level). The first describeKey error aborts
// the filtering and is returned with a nil slice.
func (r *kmsRepository) filterAliases(aliases []*kms.AliasListEntry) ([]*kms.AliasListEntry, error) {
	var kept []*kms.AliasListEntry

	for i := range aliases {
		entry := aliases[i]

		// Skip unnamed aliases and the "alias/aws/" namespace before making
		// any DescribeKey call for them.
		if entry.AliasName == nil || strings.HasPrefix(*entry.AliasName, "alias/aws/") {
			continue
		}

		described, err := r.describeKey(entry.TargetKeyId)
		if err != nil {
			return nil, err
		}

		// describeKey signals a key pending deletion with a nil output.
		if described == nil {
			logrus.WithFields(logrus.Fields{
				"id":    *entry.TargetKeyId,
				"alias": *entry.AliasName,
			}).Debug("Ignored kms key alias from listing since it is linked to a pending from deletion key")
			continue
		}

		kept = append(kept, entry)
	}

	return kept, nil
}