Merge pull request #629 from cloudskiff/issue_555_directory_location
Add config-dir flag to change .driftctl locationmain
commit
0eddd54b30
|
@ -9,6 +9,7 @@ import (
|
|||
"syscall"
|
||||
|
||||
"github.com/cloudskiff/driftctl/pkg/telemetry"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -85,6 +86,8 @@ func NewScanCmd() *cobra.Command {
|
|||
opts.Quiet, _ = cmd.Flags().GetBool("quiet")
|
||||
opts.DisableTelemetry, _ = cmd.Flags().GetBool("disable-telemetry")
|
||||
|
||||
opts.ConfigDir, _ = cmd.Flags().GetString("config-dir")
|
||||
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
|
@ -154,6 +157,16 @@ func NewScanCmd() *cobra.Command {
|
|||
"Includes cloud provider service-linked roles (disabled by default)",
|
||||
)
|
||||
|
||||
configDir, err := homedir.Dir()
|
||||
if err != nil {
|
||||
configDir = os.TempDir()
|
||||
}
|
||||
fl.String(
|
||||
"config-dir",
|
||||
configDir,
|
||||
"Directory path that driftctl uses for configuration.\n",
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
@ -174,7 +187,7 @@ func scanRun(opts *pkg.ScanOptions) error {
|
|||
|
||||
resFactory := terraform.NewTerraformResourceFactory(resourceSchemaRepository)
|
||||
|
||||
err := remote.Activate(opts.To, opts.ProviderVersion, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository, resFactory)
|
||||
err := remote.Activate(opts.To, opts.ProviderVersion, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository, resFactory, opts.ConfigDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ type ScanOptions struct {
|
|||
StrictMode bool
|
||||
DisableTelemetry bool
|
||||
ProviderVersion string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
type DriftCTL struct {
|
||||
|
|
|
@ -110,7 +110,7 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) {
|
|||
|
||||
if shouldUpdate {
|
||||
var err error
|
||||
realProvider, err = aws.NewAWSTerraformProvider("", progress)
|
||||
realProvider, err = aws.NewAWSTerraformProvider("", progress, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) {
|
|||
|
||||
if shouldUpdate {
|
||||
var err error
|
||||
realProvider, err = github.NewGithubTerraformProvider("", progress)
|
||||
realProvider, err = github.NewGithubTerraformProvider("", progress, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -23,11 +23,13 @@ func Init(version string, alerter *alerter.Alerter,
|
|||
supplierLibrary *resource.SupplierLibrary,
|
||||
progress output.Progress,
|
||||
resourceSchemaRepository *resource.SchemaRepository,
|
||||
factory resource.ResourceFactory) error {
|
||||
factory resource.ResourceFactory,
|
||||
configDir string) error {
|
||||
|
||||
if version == "" {
|
||||
version = "3.19.0"
|
||||
}
|
||||
provider, err := NewAWSTerraformProvider(version, progress)
|
||||
provider, err := NewAWSTerraformProvider(version, progress, configDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
func InitTestAwsProvider(providerLibrary *terraform.ProviderLibrary) (*AWSTerraformProvider, error) {
|
||||
progress := &output.MockProgress{}
|
||||
progress.On("Inc").Maybe().Return()
|
||||
provider, err := NewAWSTerraformProvider("", progress)
|
||||
provider, err := NewAWSTerraformProvider("", progress, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -42,12 +42,13 @@ type AWSTerraformProvider struct {
|
|||
session *session.Session
|
||||
}
|
||||
|
||||
func NewAWSTerraformProvider(version string, progress output.Progress) (*AWSTerraformProvider, error) {
|
||||
func NewAWSTerraformProvider(version string, progress output.Progress, configDir string) (*AWSTerraformProvider, error) {
|
||||
p := &AWSTerraformProvider{}
|
||||
providerKey := "aws"
|
||||
installer, err := tf.NewProviderInstaller(tf.ProviderConfig{
|
||||
Key: providerKey,
|
||||
Version: version,
|
||||
Key: providerKey,
|
||||
Version: version,
|
||||
ConfigDir: configDir,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -21,11 +21,13 @@ func Init(version string, alerter *alerter.Alerter,
|
|||
supplierLibrary *resource.SupplierLibrary,
|
||||
progress output.Progress,
|
||||
resourceSchemaRepository *resource.SchemaRepository,
|
||||
factory resource.ResourceFactory) error {
|
||||
factory resource.ResourceFactory,
|
||||
configDir string) error {
|
||||
if version == "" {
|
||||
version = "4.4.0"
|
||||
}
|
||||
provider, err := NewGithubTerraformProvider(version, progress)
|
||||
|
||||
provider, err := NewGithubTerraformProvider(version, progress, configDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func InitTestGithubProvider(providerLibrary *terraform.ProviderLibrary) (*GithubTerraformProvider, error) {
|
||||
provider, err := NewGithubTerraformProvider("", &output.MockProgress{})
|
||||
provider, err := NewGithubTerraformProvider("", &output.MockProgress{}, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -19,12 +19,13 @@ type githubConfig struct {
|
|||
Organization string
|
||||
}
|
||||
|
||||
func NewGithubTerraformProvider(version string, progress output.Progress) (*GithubTerraformProvider, error) {
|
||||
func NewGithubTerraformProvider(version string, progress output.Progress, configDir string) (*GithubTerraformProvider, error) {
|
||||
p := &GithubTerraformProvider{}
|
||||
providerKey := "github"
|
||||
installer, err := tf.NewProviderInstaller(tf.ProviderConfig{
|
||||
Key: providerKey,
|
||||
Version: version,
|
||||
Key: providerKey,
|
||||
Version: version,
|
||||
ConfigDir: configDir,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -29,12 +29,13 @@ func Activate(remote, version string, alerter *alerter.Alerter,
|
|||
supplierLibrary *resource.SupplierLibrary,
|
||||
progress output.Progress,
|
||||
resourceSchemaRepository *resource.SchemaRepository,
|
||||
factory resource.ResourceFactory) error {
|
||||
factory resource.ResourceFactory,
|
||||
configDir string) error {
|
||||
switch remote {
|
||||
case aws.RemoteAWSTerraform:
|
||||
return aws.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory)
|
||||
return aws.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory, configDir)
|
||||
case github.RemoteGithubTerraform:
|
||||
return github.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory)
|
||||
return github.Init(version, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository, factory, configDir)
|
||||
default:
|
||||
return errors.Errorf("unsupported remote '%s'", remote)
|
||||
}
|
||||
|
|
|
@ -6,8 +6,9 @@ import (
|
|||
)
|
||||
|
||||
type ProviderConfig struct {
|
||||
Key string
|
||||
Version string
|
||||
Key string
|
||||
Version string
|
||||
ConfigDir string
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetDownloadUrl() string {
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"strings"
|
||||
|
||||
error2 "github.com/cloudskiff/driftctl/pkg/terraform/error"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
|
@ -28,14 +27,10 @@ type ProviderInstaller struct {
|
|||
}
|
||||
|
||||
func NewProviderInstaller(config ProviderConfig) (*ProviderInstaller, error) {
|
||||
homedir, err := homedir.Dir()
|
||||
if err != nil {
|
||||
homedir = ""
|
||||
}
|
||||
return &ProviderInstaller{
|
||||
NewProviderDownloader(),
|
||||
config,
|
||||
homedir,
|
||||
config.ConfigDir,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -81,9 +76,6 @@ func (p *ProviderInstaller) Install() (string, error) {
|
|||
}
|
||||
|
||||
func (p ProviderInstaller) getProviderDirectory() string {
|
||||
if p.homeDir == "" {
|
||||
p.homeDir = os.TempDir()
|
||||
}
|
||||
return path.Join(p.homeDir, fmt.Sprintf("/.driftctl/plugins/%s_%s/", runtime.GOOS, runtime.GOARCH))
|
||||
}
|
||||
|
||||
|
|
|
@ -43,33 +43,6 @@ func TestProviderInstallerInstallDoesNotExist(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestProviderInstallerInstallWithoutHomeDir(t *testing.T) {
|
||||
|
||||
assert := assert.New(t)
|
||||
|
||||
expectedHomeDir := os.TempDir()
|
||||
expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH)
|
||||
config := ProviderConfig{
|
||||
Key: "aws",
|
||||
Version: "3.19.0",
|
||||
}
|
||||
|
||||
mockDownloader := mocks.ProviderDownloaderInterface{}
|
||||
mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(expectedHomeDir, expectedSubFolder)).Return(nil)
|
||||
|
||||
installer := ProviderInstaller{
|
||||
config: config,
|
||||
downloader: &mockDownloader,
|
||||
}
|
||||
|
||||
providerPath, err := installer.Install()
|
||||
mockDownloader.AssertExpectations(t)
|
||||
|
||||
assert.Nil(err)
|
||||
assert.Equal(path.Join(expectedHomeDir, expectedSubFolder, config.GetBinaryName()), providerPath)
|
||||
|
||||
}
|
||||
|
||||
func TestProviderInstallerInstallAlreadyExist(t *testing.T) {
|
||||
|
||||
assert := assert.New(t)
|
||||
|
@ -204,3 +177,30 @@ func TestProviderInstallerVersionDoesNotExist(t *testing.T) {
|
|||
|
||||
assert.Equal("Provider version 666.666.666 does not exist", err.Error())
|
||||
}
|
||||
|
||||
func TestProviderInstallerWithConfigDirectory(t *testing.T) {
|
||||
|
||||
assert := assert.New(t)
|
||||
fakeTmpHome := t.TempDir()
|
||||
|
||||
expectedSubFolder := fmt.Sprintf("/.driftctl/plugins/%s_%s", runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
config := ProviderConfig{
|
||||
Key: "aws",
|
||||
Version: "3.19.0",
|
||||
ConfigDir: fakeTmpHome,
|
||||
}
|
||||
|
||||
mockDownloader := mocks.ProviderDownloaderInterface{}
|
||||
mockDownloader.On("Download", config.GetDownloadUrl(), path.Join(fakeTmpHome, expectedSubFolder)).Return(nil)
|
||||
|
||||
installer, _ := NewProviderInstaller(config)
|
||||
installer.downloader = &mockDownloader
|
||||
|
||||
providerPath, err := installer.Install()
|
||||
mockDownloader.AssertExpectations(t)
|
||||
|
||||
assert.Nil(err)
|
||||
assert.Equal(path.Join(fakeTmpHome, expectedSubFolder, config.GetBinaryName()), providerPath)
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue