Merge pull request #671 from cloudskiff/issue_615_AWS_profile_to_s3
Add a way to override s3 config/profile with env varmain
commit
106366a628
|
@ -0,0 +1,50 @@
|
|||
package envproxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EnvProxy struct {
|
||||
fromPrefix string
|
||||
toPrefix string
|
||||
defaultEnv map[string]string
|
||||
}
|
||||
|
||||
func NewEnvProxy(fromPrefix, toPrefix string) *EnvProxy {
|
||||
envMap := map[string]string{}
|
||||
for _, variable := range os.Environ() {
|
||||
tmp := strings.SplitN(variable, "=", 2)
|
||||
envMap[tmp[0]] = tmp[1]
|
||||
}
|
||||
return &EnvProxy{
|
||||
fromPrefix: fromPrefix,
|
||||
toPrefix: toPrefix,
|
||||
defaultEnv: envMap,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EnvProxy) Apply() {
|
||||
if s.fromPrefix == "" || s.toPrefix == "" {
|
||||
return
|
||||
}
|
||||
for key, value := range s.defaultEnv {
|
||||
if strings.HasPrefix(key, s.fromPrefix) {
|
||||
key = strings.Replace(key, s.fromPrefix, s.toPrefix, 1)
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EnvProxy) Restore() {
|
||||
if s.fromPrefix == "" || s.toPrefix == "" {
|
||||
return
|
||||
}
|
||||
for key, value := range s.defaultEnv {
|
||||
if strings.HasPrefix(key, s.fromPrefix) {
|
||||
key = strings.Replace(key, s.fromPrefix, s.toPrefix, 1)
|
||||
value = s.defaultEnv[key]
|
||||
}
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
package envproxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnvProxy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
proxyArgs []string
|
||||
initialEnv []string
|
||||
modifiedEnv []string
|
||||
}{
|
||||
{
|
||||
name: "Without args on SetProxy",
|
||||
proxyArgs: []string{"", ""},
|
||||
initialEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
modifiedEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
},
|
||||
{
|
||||
name: "With args on SetProxy",
|
||||
proxyArgs: []string{"TEST_DCTL_S3_", "TEST_AWS_"},
|
||||
initialEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
modifiedEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_dctl_s3_profile"},
|
||||
},
|
||||
{
|
||||
name: "Without toPrefix on SetProxy",
|
||||
proxyArgs: []string{"TEST_DCTL_S3_", ""},
|
||||
initialEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
modifiedEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
},
|
||||
{
|
||||
name: "Without fromPrefix on SetProxy",
|
||||
proxyArgs: []string{"", "TEST_AWS_"},
|
||||
initialEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
modifiedEnv: []string{"TEST_DCTL_S3_PROFILE=test_dctl_s3_profile", "TEST_AWS_PROFILE=test_aws_profile"},
|
||||
},
|
||||
{
|
||||
name: "Without initialEnv",
|
||||
proxyArgs: []string{"TEST_DCTL_S3_", "TEST_AWS_"},
|
||||
initialEnv: []string{},
|
||||
modifiedEnv: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
for _, value := range tt.initialEnv {
|
||||
tmp := strings.SplitN(value, "=", 2)
|
||||
os.Setenv(tmp[0], tmp[1])
|
||||
}
|
||||
|
||||
envProxy := NewEnvProxy(tt.proxyArgs[0], tt.proxyArgs[1])
|
||||
|
||||
envProxy.Apply()
|
||||
|
||||
currentEnv := os.Environ()
|
||||
if !compareEnv(currentEnv, tt.modifiedEnv) {
|
||||
t.Errorf("Expected %v, got %v", tt.modifiedEnv, currentEnv)
|
||||
}
|
||||
|
||||
envProxy.Restore()
|
||||
|
||||
currentEnv = os.Environ()
|
||||
if !compareEnv(currentEnv, tt.initialEnv) {
|
||||
t.Errorf("Expected %v, got %v", tt.initialEnv, currentEnv)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareEnv(currentEnv, testEnv []string) bool {
|
||||
isValid := 0
|
||||
for _, initialValue := range testEnv {
|
||||
for _, value := range currentEnv {
|
||||
if initialValue == value {
|
||||
isValid++
|
||||
}
|
||||
}
|
||||
}
|
||||
return isValid == len(testEnv)
|
||||
}
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/cloudskiff/driftctl/pkg/envproxy"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
|
@ -33,9 +34,12 @@ func NewS3Reader(path string) (*S3Backend, error) {
|
|||
Key: &key,
|
||||
Bucket: &bucket,
|
||||
}
|
||||
envProxy := envproxy.NewEnvProxy("DCTL_S3_", "AWS_")
|
||||
envProxy.Apply()
|
||||
sess := session.Must(session.NewSessionWithOptions(session.Options{
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
}))
|
||||
envProxy.Restore()
|
||||
backend.S3Client = s3.New(sess)
|
||||
return &backend, nil
|
||||
}
|
||||
|
|
|
@ -67,6 +67,31 @@ func TestNewS3Reader(t *testing.T) {
|
|||
)
|
||||
}
|
||||
|
||||
func TestNewS3ReaderWithEnvProxy(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
os.Setenv("AWS_DEFAULT_REGION", "us-east-1")
|
||||
os.Setenv("DCTL_S3_DEFAULT_REGION", "eu-west-3")
|
||||
reader, err := NewS3Reader("sample_bucket/path/to/state.tfstate")
|
||||
|
||||
got := reader.S3Client.(*s3.S3).Config.Region
|
||||
if aws.StringValue(got) != "eu-west-3" {
|
||||
t.Errorf("NewS3Reader().S3Client.Config.Region got = %v, want %v", aws.StringValue(got), "eu-west-3")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
assert.Equal(
|
||||
"path/to/state.tfstate",
|
||||
*reader.input.Key,
|
||||
)
|
||||
assert.Equal(
|
||||
"sample_bucket",
|
||||
*reader.input.Bucket,
|
||||
)
|
||||
}
|
||||
|
||||
func TestS3Backend_ReadWithError(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
fakeS3 := &awstest.MockFakeS3{}
|
||||
|
|
|
@ -8,26 +8,25 @@ import (
|
|||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3iface"
|
||||
"github.com/cloudskiff/driftctl/pkg/envproxy"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/bmatcuk/doublestar/v4"
|
||||
"github.com/cloudskiff/driftctl/pkg/iac/config"
|
||||
)
|
||||
|
||||
type S3EnumeratorConfig struct {
|
||||
Bucket *string
|
||||
Prefix *string
|
||||
}
|
||||
|
||||
type S3Enumerator struct {
|
||||
config config.SupplierConfig
|
||||
client s3iface.S3API
|
||||
}
|
||||
|
||||
func NewS3Enumerator(config config.SupplierConfig) *S3Enumerator {
|
||||
envProxy := envproxy.NewEnvProxy("DCTL_S3_", "AWS_")
|
||||
envProxy.Apply()
|
||||
sess := session.Must(session.NewSessionWithOptions(session.Options{
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
}))
|
||||
envProxy.Restore()
|
||||
return &S3Enumerator{
|
||||
config,
|
||||
s3.New(sess),
|
||||
|
|
|
@ -2,6 +2,7 @@ package enumerator
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -12,6 +13,52 @@ import (
|
|||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestS3Enumerator_NewS3Enumerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config config.SupplierConfig
|
||||
setEnv map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "test with no proxy env var",
|
||||
config: config.SupplierConfig{
|
||||
Key: "tfstate",
|
||||
Backend: "s3",
|
||||
Path: "terraform.tfstate",
|
||||
},
|
||||
setEnv: map[string]string{
|
||||
"AWS_DEFAULT_REGION": "us-east-1",
|
||||
},
|
||||
want: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "test with proxy env var",
|
||||
config: config.SupplierConfig{
|
||||
Key: "tfstate",
|
||||
Backend: "s3",
|
||||
Path: "terraform.tfstate",
|
||||
},
|
||||
setEnv: map[string]string{
|
||||
"AWS_DEFAULT_REGION": "us-east-1",
|
||||
"DCTL_S3_DEFAULT_REGION": "eu-west-3",
|
||||
},
|
||||
want: "eu-west-3",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for key, value := range tt.setEnv {
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
got := NewS3Enumerator(tt.config).client.(*s3.S3).Config.Region
|
||||
if awssdk.StringValue(got) != tt.want {
|
||||
t.Errorf("NewS3Enumerator().client.Config.Region got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3Enumerator_Enumerate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
Loading…
Reference in New Issue