Merge pull request #629 from cloudskiff/issue_555_directory_location

Add config-dir flag to change .driftctl location
main
Raphaël 2021-06-17 14:40:27 +02:00 committed by GitHub
commit 0eddd54b30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 70 additions and 56 deletions

View File

@ -9,6 +9,7 @@ import (
"syscall"
"github.com/cloudskiff/driftctl/pkg/telemetry"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@ -85,6 +86,8 @@ func NewScanCmd() *cobra.Command {
opts.Quiet, _ = cmd.Flags().GetBool("quiet")
opts.DisableTelemetry, _ = cmd.Flags().GetBool("disable-telemetry")
opts.ConfigDir, _ = cmd.Flags().GetString("config-dir")
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
@ -154,6 +157,16 @@ func NewScanCmd() *cobra.Command {
"Includes cloud provider service-linked roles (disabled by default)",
)
configDir, err := homedir.Dir()
if err != nil {
configDir = os.TempDir()
}
fl.String(
"config-dir",
configDir,
"Directory path that driftctl uses for configuration.\n",
)
return cmd
}
@ -174,7 +187,7 @@ func scanRun(opts *pkg.ScanOptions) error {
resFactory := terraform.NewTerraformResourceFactory(resourceSchemaRepository)
err := remote.Activate(opts.To, opts.ProviderVersion, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository, resFactory)
err := remote.Activate(opts.To, opts.ProviderVersion, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository, resFactory, opts.ConfigDir)
if err != nil {
return err
}

View File

@ -30,6 +30,7 @@ type ScanOptions struct {
StrictMode bool
DisableTelemetry bool
ProviderVersion string
ConfigDir string
}
type DriftCTL struct {

View File

@ -110,7 +110,7 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) {
if shouldUpdate {
var err error
realProvider, err = aws.NewAWSTerraformProvider("", progress)
realProvider, err = aws.NewAWSTerraformProvider("", progress, "")
if err != nil {
t.Fatal(err)
}
@ -195,7 +195,7 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) {
if shouldUpdate {
var err error
realProvider, err = github.NewGithubTerraformProvider("", progress)
realProvider, err = github.NewGithubTerraformProvider("", progress, "")
if err != nil {
t.Fatal(err)
}

View File

@ -23,11 +23,13 @@ func Init(version string, alerter *alerter.Alerter,
supplierLibrary *resource.SupplierLibrary,
progress output.Progress,
resourceSchemaRepository *resource.SchemaRepository,
factory resource.ResourceFactory) error {
factory resource.ResourceFactory,
configDir string) error {
if version == "" {
version = "3.19.0"
}
provider, err := NewAWSTerraformProvider(version, progress)
provider, err := NewAWSTerraformProvider(version, progress, configDir)
if err != nil {
return err
}

View File

@ -8,7 +8,7 @@ import (
func InitTestAwsProvider(providerLibrary *terraform.ProviderLibrary) (*AWSTerraformProvider, error) {
progress := &output.MockProgress{}
progress.On("Inc").Maybe().Return()
provider, err := NewAWSTerraformProvider("", progress)
provider, err := NewAWSTerraformProvider("", progress, "")
if err != nil {
return nil, err
}

View File

@ -42,12 +42,13 @@ type AWSTerraformProvider struct {
session *session.Session
}
func NewAWSTerraformProvider(version string, progress output.Progress) (*AWSTerraformProvider, error) {
func NewAWSTerraformProvider(version string, progress output.Progress, configDir string) (*AWSTerraformProvider, error) {
p := &AWSTerraformProvider{}
providerKey := "aws"
installer, err := tf.NewProviderInstaller(tf.ProviderConfig{
Key: providerKey,
Version: version,
ConfigDir: configDir,
})
if err != nil {
return nil, err

View File

@ -21,11 +21,13 @@ func Init(version string, alerter *alerter.Alerter,
supplierLibrary *resource.SupplierLibrary,
progress output.Progress,
resourceSchemaRepository *resource.SchemaRepository,
factory resource.ResourceFactory) error {
factory resource.ResourceFactory,
configDir string) error {
if version == "" {
version = "4.4.0"
}
provider, err := NewGithubTerraformProvider(version, progress)
provider, err := NewGithubTerraformProvider(version, progress, configDir)
if err != nil {
return err
}

View File

@ -6,7 +6,7 @@ import (
)
func InitTestGithubProvider(providerLibrary *terraform.ProviderLibrary) (*GithubTerraformProvider, error) {
provider, err := NewGithubTerraformProvider("", &output.MockProgress{})
provider, err := NewGithubTerraformProvider("", &output.MockProgress{}, "")
if err != nil {
return nil, err
}

View File

@ -19,12 +19,13 @@ type githubConfig struct {
Organization string
}
func NewGithubTerraformProvider(version string, progress output.Progress) (*GithubTerraformProvider, error) {
func NewGithubTerraformProvider(version string, progress output.Progress, configDir string) (*GithubTerraformProvider, error) {
p := &GithubTerraformProvider{}
providerKey := "github"
installer, err := tf.NewProviderInstaller(tf.ProviderConfig{
Key: providerKey,
Version: version,
ConfigDir: configDir,
})
if err != nil {
return nil, err

View File

@ -29,12 +29,13 @@ func Activate(remote, version string, alerter *alerter.Alerter,
supplierLibrary *resource.SupplierLibrary,
progress output.Progress,
resourceSchemaRepository *resource.SchemaRepository,
factory resource.ResourceFactory) error {
factory resource.ResourceFactory,
configDir string) error {
switch remote {
case aws.RemoteAWSTerraform:
return aws.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory)
return aws.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory, configDir)
case github.RemoteGithubTerraform:
return github.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory)
return github.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory, configDir)
default:
return errors.Errorf("unsupported remote '%s'", remote)
}

View File

@ -8,6 +8,7 @@ import (
type ProviderConfig struct {
Key string
Version string
ConfigDir string
}
func (c *ProviderConfig) GetDownloadUrl() string {

View File

@ -10,7 +10,6 @@ import (
"strings"
error2 "github.com/cloudskiff/driftctl/pkg/terraform/error"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@ -28,14 +27,10 @@ type ProviderInstaller struct {
}
func NewProviderInstaller(config ProviderConfig) (*ProviderInstaller, error) {
homedir, err := homedir.Dir()
if err != nil {
homedir = ""
}
return &ProviderInstaller{
NewProviderDownloader(),
config,
homedir,
config.ConfigDir,
}, nil
}
@ -81,9 +76,6 @@ func (p *ProviderInstaller) Install() (string, error) {
}
func (p ProviderInstaller) getProviderDirectory() string {
if p.homeDir == "" {
p.homeDir = os.TempDir()
}
return path.Join(p.homeDir, fmt.Sprintf("/.driftctl/plugins/%s_%s/", runtime.GOOS, runtime.GOARCH))
}

View File

@ -43,33 +43,6 @@ func TestProviderInstallerInstallDoesNotExist(t *testing.T) {
}
func TestProviderInstallerInstallWithoutHomeDir(t *testing.T) {
assert := assert.New(t)
expectedHomeDir := os.TempDir()
expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH)
config := ProviderConfig{
Key: "aws",
Version: "3.19.0",
}
mockDownloader := mocks.ProviderDownloaderInterface{}
mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(expectedHomeDir, expectedSubFolder)).Return(nil)
installer := ProviderInstaller{
config: config,
downloader: &mockDownloader,
}
providerPath, err := installer.Install()
mockDownloader.AssertExpectations(t)
assert.Nil(err)
assert.Equal(path.Join(expectedHomeDir, expectedSubFolder, config.GetBinaryName()), providerPath)
}
func TestProviderInstallerInstallAlreadyExist(t *testing.T) {
assert := assert.New(t)
@ -204,3 +177,30 @@ func TestProviderInstallerVersionDoesNotExist(t *testing.T) {
assert.Equal("Provider version 666.666.666 does not exist", err.Error())
}
func TestProviderInstallerWithConfigDirectory(t *testing.T) {
assert := assert.New(t)
fakeTmpHome := t.TempDir()
expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH)
config := ProviderConfig{
Key: "aws",
Version: "3.19.0",
ConfigDir: fakeTmpHome,
}
mockDownloader := mocks.ProviderDownloaderInterface{}
mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil)
installer, _ := NewProviderInstaller(config)
installer.downloader = &mockDownloader
providerPath, err := installer.Install()
mockDownloader.AssertExpectations(t)
assert.Nil(err)
assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath)
}