Initial commit for AWS Catalog (#3372)

dev
Leo Loobeek 2023-03-23 14:06:54 -05:00 committed by GitHub
parent 710ac0839c
commit 6659402042
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 472 additions and 0 deletions

View File

@ -0,0 +1,187 @@
package aws
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
)
// Catalog manages the AWS S3 template catalog
type Catalog struct {
svc client
}
// client interface abstracts S3 connections
type client interface {
getAllKeys() ([]string, error)
downloadKey(name string) (io.ReadCloser, error)
setBucket(bucket string)
}
type s3svc struct {
client *s3.Client
bucket string
}
// NewCatalog creates a new AWS Catalog object given a required S3 bucket name and optional configurations. If
// no configurations to set AWS keys are provided then environment variables will be used to obtain AWS credentials.
func NewCatalog(bucket string, configurations ...func(*Catalog) error) (Catalog, error) {
var c Catalog
for _, configuration := range configurations {
err := configuration(&c)
if err != nil {
return c, err
}
}
if c.svc == nil {
cfg, err := config.LoadDefaultConfig(context.TODO())
if err != nil {
return c, err
}
c.svc = &s3svc{
client: s3.NewFromConfig(cfg),
}
}
c.svc.setBucket(bucket)
return c, nil
}
// WithAWSKeys enables explicitly setting the AWS access key, secret key and region
func WithAWSKeys(accessKey, secretKey, region string) func(*Catalog) error {
return func(c *Catalog) error {
cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")),
config.WithRegion(region))
if err != nil {
return err
}
c.svc = &s3svc{
client: s3.NewFromConfig(cfg),
bucket: "",
}
return nil
}
}
// OpenFile downloads a file from S3 and returns the contents as an io.ReadCloser
func (c Catalog) OpenFile(filename string) (io.ReadCloser, error) {
if filename == "" {
return nil, errors.New("empty filename")
}
return c.svc.downloadKey(filename)
}
// GetTemplatePath looks for a target string performing a simple substring check
// against all S3 keys. If the input includes a wildcard (*) it is removed.
func (c Catalog) GetTemplatePath(target string) ([]string, error) {
target = strings.ReplaceAll(target, "*", "")
keys, err := c.svc.getAllKeys()
if err != nil {
return nil, err
}
var matches []string
for _, key := range keys {
if strings.Contains(key, target) {
matches = append(matches, key)
}
}
return matches, nil
}
// GetTemplatesPath returns all templates from S3
func (c Catalog) GetTemplatesPath(definitions []string) ([]string, map[string]error) {
keys, err := c.svc.getAllKeys()
if err != nil {
// necessary to implement the Catalog interface
return nil, map[string]error{"aws": err}
}
return keys, nil
}
// ResolvePath gets a full S3 key given the first param. If the second parameter is
// provided it tries to find paths relative to the second path.
func (c Catalog) ResolvePath(templateName, second string) (string, error) {
keys, err := c.svc.getAllKeys()
if err != nil {
return "", err
}
// if c second path is given, it's c folder and we join the two and check against keys
if second != "" {
target := filepath.Join(filepath.Dir(second), templateName)
for _, key := range keys {
if key == target {
return key, nil
}
}
}
// check if templateName is already an absolute path to c key
for _, key := range keys {
if key == templateName {
return templateName, nil
}
}
return "", fmt.Errorf("no such path found: %s", templateName)
}
func (s *s3svc) getAllKeys() ([]string, error) {
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
Bucket: &s.bucket,
})
var keys []string
for paginator.HasMorePages() {
page, err := paginator.NextPage(context.TODO())
if err != nil {
return nil, err
}
for _, obj := range page.Contents {
key := aws.ToString(obj.Key)
keys = append(keys, key)
}
}
return keys, nil
}
func (s *s3svc) downloadKey(name string) (io.ReadCloser, error) {
downloader := manager.NewDownloader(s.client)
buf := manager.NewWriteAtBuffer([]byte{})
_, err := downloader.Download(context.TODO(), buf, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(name),
})
if err != nil {
return nil, err
}
return io.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
func (s *s3svc) setBucket(bucket string) {
s.bucket = bucket
}

View File

@ -0,0 +1,285 @@
package aws
import (
"github.com/pkg/errors"
"io"
"reflect"
"strings"
"testing"
)
func TestCatalog_GetTemplatePath(t *testing.T) {
type args struct {
target string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
"get all ssl files",
args{
target: "ssl",
},
[]string{
"ssl/deprecated-tls.yaml",
"ssl/detect-ssl-issuer.yaml",
"ssl/expired-ssl.yaml",
"ssl/mismatched-ssl.yaml",
},
false,
},
{
"get all ssl files with wildcard",
args{
target: "ssl*",
},
[]string{
"ssl/deprecated-tls.yaml",
"ssl/detect-ssl-issuer.yaml",
"ssl/expired-ssl.yaml",
"ssl/mismatched-ssl.yaml",
},
false,
},
{
"non-matching target",
args{
target: "I-DONT-EXIST",
},
[]string{},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, _ := NewCatalog("bucket", withMockS3Service())
got, err := c.GetTemplatePath(tt.args.target)
if (err != nil) != tt.wantErr {
t.Errorf("GetTemplatePath() error = %v, wantErr %v", err, tt.wantErr)
return
}
if len(tt.want) > 0 && !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetTemplatePath() got = %v, want %v", got, tt.want)
}
if len(tt.want) == 0 && len(got) > 0 {
t.Errorf("GetTemplatePath() got = %v, want %v", got, tt.want)
}
})
}
}
func TestCatalog_GetTemplatesPath(t *testing.T) {
tmp := newMockS3Service()
keys, _ := tmp.getAllKeys()
type args struct {
definitions []string
}
tests := []struct {
name string
args args
want []string
wantErr bool
}{
{
"without definitions",
args{
definitions: nil,
},
keys,
false,
},
{
"with definitions",
args{
definitions: []string{"ssl/deprecated-tls.yaml"},
},
keys,
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, _ := NewCatalog("bucket", withMockS3Service())
got, got1 := c.GetTemplatesPath(tt.args.definitions)
if got1 != nil {
val, exists := got1["aws"]
if exists && !tt.wantErr {
t.Errorf("GetTemplatesPath() error = %v, wantErr %v", val, tt.wantErr)
}
if !exists && len(got1) > 0 {
t.Errorf("GetTemplatesPath() should only return one key 'aws': %v", got1)
}
if !exists && tt.wantErr {
t.Errorf("GetTemplatesPath() error = %v, wantErr %v", val, tt.wantErr)
}
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetTemplatesPath() got = %v, want %v", got, tt.want)
}
})
}
}
func TestCatalog_OpenFile(t *testing.T) {
tests := []struct {
name string
filename string
wantErr bool
}{
{
"valid key",
"ssl/deprecated-tls.yaml",
false,
},
{
"non-existent key",
"something/that-doesnt-exist.yaml",
true,
},
{
"path to folder",
"cves/2023",
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, _ := NewCatalog("bucket", withMockS3Service())
got, err := c.OpenFile(tt.filename)
if (err != nil) != tt.wantErr {
t.Errorf("OpenFile() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil && got == nil {
t.Error("OpenFile() didn't return error but io.ReadCloser is nil")
}
})
}
}
func TestCatalog_ResolvePath(t *testing.T) {
type args struct {
templateName string
second string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
"absolute path",
args{
"ssl/deprecated-tls.yaml",
"",
},
"ssl/deprecated-tls.yaml",
false,
},
{
"relative path with second param",
args{
"deprecated-tls.yaml",
"ssl/",
},
"ssl/deprecated-tls.yaml",
false,
},
{
"relative path and no second param",
args{
"cves/2023",
"",
},
"",
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, _ := NewCatalog("bucket", withMockS3Service())
got, err := c.ResolvePath(tt.args.templateName, tt.args.second)
if (err != nil) != tt.wantErr {
t.Errorf("ResolvePath() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("ResolvePath() got = %v, want %v", got, tt.want)
}
})
}
}
func withMockS3Service() func(*Catalog) error {
return func(c *Catalog) error {
c.svc = newMockS3Service()
return nil
}
}
type mocks3svc struct {
keys []string
}
func newMockS3Service() mocks3svc {
return mocks3svc{
keys: []string{
"ssl/deprecated-tls.yaml",
"ssl/detect-ssl-issuer.yaml",
"ssl/expired-ssl.yaml",
"ssl/mismatched-ssl.yaml",
"cves/2023/CVE-2023-0669.yaml",
"cves/2023/CVE-2023-23488.yaml",
"cves/2023/CVE-2023-23489.yaml",
},
}
}
func (m mocks3svc) getAllKeys() ([]string, error) {
return m.keys, nil
}
func (m mocks3svc) downloadKey(name string) (io.ReadCloser, error) {
found := false
for _, key := range m.keys {
if key == name {
found = true
break
}
}
if !found {
return nil, errors.New("key not found")
}
sample := `
id: git-config
info:
name: Git Config File
author: Ice3man
severity: medium
description: Searches for the pattern /.git/config on passed URLs.
requests:
- method: GET
path:
- "{{BaseURL}}/.git/config"
matchers:
- type: word
words:
- "[core]"
`
return io.NopCloser(strings.NewReader(sample)), nil
}
func (m mocks3svc) setBucket(bucket string) {}