diff --git a/pkg/cmd/scan.go b/pkg/cmd/scan.go index b6e11569..453bdbd4 100644 --- a/pkg/cmd/scan.go +++ b/pkg/cmd/scan.go @@ -9,6 +9,7 @@ import ( "syscall" "github.com/cloudskiff/driftctl/pkg/telemetry" + "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -85,6 +86,8 @@ func NewScanCmd() *cobra.Command { opts.Quiet, _ = cmd.Flags().GetBool("quiet") opts.DisableTelemetry, _ = cmd.Flags().GetBool("disable-telemetry") + opts.ConfigDir, _ = cmd.Flags().GetString("config-dir") + return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -154,6 +157,16 @@ func NewScanCmd() *cobra.Command { "Includes cloud provider service-linked roles (disabled by default)", ) + configDir, err := homedir.Dir() + if err != nil { + configDir = os.TempDir() + } + fl.String( + "config-dir", + configDir, + "Directory path that driftctl uses for configuration.\n", + ) + return cmd } @@ -174,7 +187,7 @@ func scanRun(opts *pkg.ScanOptions) error { resFactory := terraform.NewTerraformResourceFactory(resourceSchemaRepository) - err := remote.Activate(opts.To, opts.ProviderVersion, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository, resFactory) + err := remote.Activate(opts.To, opts.ProviderVersion, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository, resFactory, opts.ConfigDir) if err != nil { return err } diff --git a/pkg/driftctl.go b/pkg/driftctl.go index 2ddbdeb7..47c7e8d7 100644 --- a/pkg/driftctl.go +++ b/pkg/driftctl.go @@ -30,6 +30,7 @@ type ScanOptions struct { StrictMode bool DisableTelemetry bool ProviderVersion string + ConfigDir string } type DriftCTL struct { diff --git a/pkg/iac/terraform/state/terraform_state_reader_test.go b/pkg/iac/terraform/state/terraform_state_reader_test.go index 0b6b824f..f5514e47 100644 --- a/pkg/iac/terraform/state/terraform_state_reader_test.go +++ b/pkg/iac/terraform/state/terraform_state_reader_test.go @@ -110,7 +110,7 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) { if shouldUpdate { var err error - realProvider, err = aws.NewAWSTerraformProvider("", progress) + realProvider, err = aws.NewAWSTerraformProvider("", progress, "") if err != nil { t.Fatal(err) } @@ -195,7 +195,7 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) { if shouldUpdate { var err error - realProvider, err = github.NewGithubTerraformProvider("", progress) + realProvider, err = github.NewGithubTerraformProvider("", progress, "") if err != nil { t.Fatal(err) } diff --git a/pkg/remote/aws/init.go b/pkg/remote/aws/init.go index 99e6ed3d..3447acba 100644 --- a/pkg/remote/aws/init.go +++ b/pkg/remote/aws/init.go @@ -23,11 +23,13 @@ func Init(version string, alerter *alerter.Alerter, supplierLibrary *resource.SupplierLibrary, progress output.Progress, resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory) error { + factory resource.ResourceFactory, + configDir string) error { + if version == "" { version = "3.19.0" } - provider, err := NewAWSTerraformProvider(version, progress) + provider, err := NewAWSTerraformProvider(version, progress, configDir) if err != nil { return err } diff --git a/pkg/remote/aws/init_test.go b/pkg/remote/aws/init_test.go index d1316c90..298a92c5 100644 --- a/pkg/remote/aws/init_test.go +++ b/pkg/remote/aws/init_test.go @@ -8,7 +8,7 @@ import ( func InitTestAwsProvider(providerLibrary *terraform.ProviderLibrary) (*AWSTerraformProvider, error) { progress := &output.MockProgress{} progress.On("Inc").Maybe().Return() - provider, err := NewAWSTerraformProvider("", progress) + provider, err := NewAWSTerraformProvider("", progress, "") if err != nil { return nil, err } diff --git a/pkg/remote/aws/provider.go b/pkg/remote/aws/provider.go index 07e344ff..825ab674 100644 --- a/pkg/remote/aws/provider.go +++ b/pkg/remote/aws/provider.go @@ -42,12 +42,13 @@ type AWSTerraformProvider struct { session *session.Session } -func NewAWSTerraformProvider(version string, progress output.Progress) (*AWSTerraformProvider, error) { +func NewAWSTerraformProvider(version string, progress output.Progress, configDir string) (*AWSTerraformProvider, error) { p := &AWSTerraformProvider{} providerKey := "aws" installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ - Key: providerKey, - Version: version, + Key: providerKey, + Version: version, + ConfigDir: configDir, }) if err != nil { return nil, err diff --git a/pkg/remote/github/init.go b/pkg/remote/github/init.go index 199c2c66..b8bfd881 100644 --- a/pkg/remote/github/init.go +++ b/pkg/remote/github/init.go @@ -21,11 +21,13 @@ func Init(version string, alerter *alerter.Alerter, supplierLibrary *resource.SupplierLibrary, progress output.Progress, resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory) error { + factory resource.ResourceFactory, + configDir string) error { if version == "" { version = "4.4.0" } - provider, err := NewGithubTerraformProvider(version, progress) + + provider, err := NewGithubTerraformProvider(version, progress, configDir) if err != nil { return err } diff --git a/pkg/remote/github/init_test.go b/pkg/remote/github/init_test.go index 741e7319..f4d4687a 100644 --- a/pkg/remote/github/init_test.go +++ b/pkg/remote/github/init_test.go @@ -6,7 +6,7 @@ import ( ) func InitTestGithubProvider(providerLibrary *terraform.ProviderLibrary) (*GithubTerraformProvider, error) { - provider, err := NewGithubTerraformProvider("", &output.MockProgress{}) + provider, err := NewGithubTerraformProvider("", &output.MockProgress{}, "") if err != nil { return nil, err } diff --git a/pkg/remote/github/provider.go b/pkg/remote/github/provider.go index 57dd1eca..5cb4b50d 100644 --- a/pkg/remote/github/provider.go +++ b/pkg/remote/github/provider.go @@ -19,12 +19,13 @@ type githubConfig struct { Organization string } -func NewGithubTerraformProvider(version string, progress output.Progress) (*GithubTerraformProvider, error) { +func NewGithubTerraformProvider(version string, progress output.Progress, configDir string) (*GithubTerraformProvider, error) { p := &GithubTerraformProvider{} providerKey := "github" installer, err := tf.NewProviderInstaller(tf.ProviderConfig{ - Key: providerKey, - Version: version, + Key: providerKey, + Version: version, + ConfigDir: configDir, }) if err != nil { return nil, err diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 8879dcd8..f1c9eb86 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -29,12 +29,13 @@ func Activate(remote, version string, alerter *alerter.Alerter, supplierLibrary *resource.SupplierLibrary, progress output.Progress, resourceSchemaRepository *resource.SchemaRepository, - factory resource.ResourceFactory) error { + factory resource.ResourceFactory, + configDir string) error { switch remote { case aws.RemoteAWSTerraform: - return aws.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory) + return aws.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory, configDir) case github.RemoteGithubTerraform: - return github.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory) + return github.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory, configDir) default: return errors.Errorf("unsupported remote '%s'", remote) } diff --git a/pkg/terraform/provider_config.go b/pkg/terraform/provider_config.go index 30b012f5..c1d89503 100644 --- a/pkg/terraform/provider_config.go +++ b/pkg/terraform/provider_config.go @@ -6,8 +6,9 @@ import ( ) type ProviderConfig struct { - Key string - Version string + Key string + Version string + ConfigDir string } func (c *ProviderConfig) GetDownloadUrl() string { diff --git a/pkg/terraform/provider_installer.go b/pkg/terraform/provider_installer.go index d7751834..f1bcbdeb 100644 --- a/pkg/terraform/provider_installer.go +++ b/pkg/terraform/provider_installer.go @@ -10,7 +10,6 @@ import ( "strings" error2 "github.com/cloudskiff/driftctl/pkg/terraform/error" - "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -28,14 +27,10 @@ type ProviderInstaller struct { } func NewProviderInstaller(config ProviderConfig) (*ProviderInstaller, error) { - homedir, err := homedir.Dir() - if err != nil { - homedir = "" - } return &ProviderInstaller{ NewProviderDownloader(), config, - homedir, + config.ConfigDir, }, nil } @@ -81,9 +76,6 @@ func (p *ProviderInstaller) Install() (string, error) { } func (p ProviderInstaller) getProviderDirectory() string { - if p.homeDir == "" { - p.homeDir = os.TempDir() - } return path.Join(p.homeDir, fmt.Sprintf("/.driftctl/plugins/%s_%s/", runtime.GOOS, runtime.GOARCH)) } diff --git a/pkg/terraform/provider_installer_test.go b/pkg/terraform/provider_installer_test.go index 12591ce7..ecf930ac 100644 --- a/pkg/terraform/provider_installer_test.go +++ b/pkg/terraform/provider_installer_test.go @@ -43,33 +43,6 @@ func TestProviderInstallerInstallDoesNotExist(t *testing.T) { } -func TestProviderInstallerInstallWithoutHomeDir(t *testing.T) { - - assert := assert.New(t) - - expectedHomeDir := os.TempDir() - expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) - config := ProviderConfig{ - Key: "aws", - Version: "3.19.0", - } - - mockDownloader := mocks.ProviderDownloaderInterface{} - mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(expectedHomeDir, expectedSubFolder)).Return(nil) - - installer := ProviderInstaller{ - config: config, - downloader: &mockDownloader, - } - - providerPath, err := installer.Install() - mockDownloader.AssertExpectations(t) - - assert.Nil(err) - assert.Equal(path.Join(expectedHomeDir, expectedSubFolder, config.GetBinaryName()), providerPath) - -} - func TestProviderInstallerInstallAlreadyExist(t *testing.T) { assert := assert.New(t) @@ -204,3 +177,30 @@ func TestProviderInstallerVersionDoesNotExist(t *testing.T) { assert.Equal("Provider version 666.666.666 does not exist", err.Error()) } + +func TestProviderInstallerWithConfigDirectory(t *testing.T) { + + assert := assert.New(t) + fakeTmpHome := t.TempDir() + + expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH) + + config := ProviderConfig{ + Key: "aws", + Version: "3.19.0", + ConfigDir: fakeTmpHome, + } + + mockDownloader := mocks.ProviderDownloaderInterface{} + mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil) + + installer, _ := NewProviderInstaller(config) + installer.downloader = &mockDownloader + + providerPath, err := installer.Install() + mockDownloader.AssertExpectations(t) + + assert.Nil(err) + assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath) + +}