diff --git a/pkg/cmd/scan.go b/pkg/cmd/scan.go index 5614d9e3..64cd057f 100644 --- a/pkg/cmd/scan.go +++ b/pkg/cmd/scan.go @@ -147,11 +147,12 @@ func scanRun(opts *pkg.ScanOptions) error { providerLibrary := terraform.NewProviderLibrary() supplierLibrary := resource.NewSupplierLibrary() - progress := globaloutput.NewProgress() + iacProgress := globaloutput.NewProgress("Scanning states", "Scanned states", true) + scanProgress := globaloutput.NewProgress("Scanning resources", "Scanned resources", true) resourceSchemaRepository := resource.NewSchemaRepository() - err := remote.Activate(opts.To, alerter, providerLibrary, supplierLibrary, progress, resourceSchemaRepository) + err := remote.Activate(opts.To, alerter, providerLibrary, supplierLibrary, scanProgress, resourceSchemaRepository) if err != nil { return err } @@ -165,14 +166,14 @@ func scanRun(opts *pkg.ScanOptions) error { scanner := pkg.NewScanner(supplierLibrary.Suppliers(), alerter, resourceSchemaRepository) - iacSupplier, err := supplier.GetIACSupplier(opts.From, providerLibrary, opts.BackendOptions, resourceSchemaRepository) + iacSupplier, err := supplier.GetIACSupplier(opts.From, providerLibrary, opts.BackendOptions, iacProgress, resourceSchemaRepository) if err != nil { return err } resFactory := terraform.NewTerraformResourceFactory(providerLibrary, resourceSchemaRepository) - ctl := pkg.NewDriftCTL(scanner, iacSupplier, alerter, resFactory, opts, resourceSchemaRepository) + ctl := pkg.NewDriftCTL(scanner, iacSupplier, alerter, resFactory, opts, scanProgress, iacProgress, resourceSchemaRepository) go func() { <-c @@ -180,10 +181,7 @@ func scanRun(opts *pkg.ScanOptions) error { ctl.Stop() }() - progress.Start() analysis, err := ctl.Run() - progress.Stop() - if err != nil { return err } diff --git a/pkg/driftctl.go b/pkg/driftctl.go index a44f8293..3b204ac0 100644 --- a/pkg/driftctl.go +++ b/pkg/driftctl.go @@ -3,6 +3,7 @@ package pkg import ( "fmt" + globaloutput "github.com/cloudskiff/driftctl/pkg/output" "github.com/jmespath/go-jmespath" "github.com/sirupsen/logrus" @@ -36,10 +37,12 @@ type DriftCTL struct { filter *jmespath.JMESPath resourceFactory resource.ResourceFactory strictMode bool + scanProgress globaloutput.Progress + iacProgress globaloutput.Progress resourceSchemaRepository resource.SchemaRepositoryInterface } -func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, alerter *alerter.Alerter, resFactory resource.ResourceFactory, opts *ScanOptions, resourceSchemaRepository resource.SchemaRepositoryInterface) *DriftCTL { +func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier, alerter *alerter.Alerter, resFactory resource.ResourceFactory, opts *ScanOptions, scanProgress globaloutput.Progress, iacProgress globaloutput.Progress, resourceSchemaRepository resource.SchemaRepositoryInterface) *DriftCTL { return &DriftCTL{ remoteSupplier, iacSupplier, @@ -48,6 +51,8 @@ func NewDriftCTL(remoteSupplier resource.Supplier, iacSupplier resource.Supplier opts.Filter, resFactory, opts.StrictMode, + scanProgress, + iacProgress, resourceSchemaRepository, } } @@ -133,12 +138,16 @@ func (d DriftCTL) Stop() { func (d DriftCTL) scan() (remoteResources []resource.Resource, resourcesFromState []resource.Resource, err error) { logrus.Info("Start reading IaC") + d.iacProgress.Start() resourcesFromState, err = d.iacSupplier.Resources() if err != nil { return nil, nil, err } + d.iacProgress.Stop() logrus.Info("Start scanning cloud provider") + d.scanProgress.Start() + defer d.scanProgress.Stop() remoteResources, err = d.remoteSupplier.Resources() if err != nil { return nil, nil, err diff --git a/pkg/driftctl_test.go b/pkg/driftctl_test.go index 62e6c5f4..a23d21b9 100644 --- a/pkg/driftctl_test.go +++ b/pkg/driftctl_test.go @@ -13,6 +13,7 @@ import ( "github.com/cloudskiff/driftctl/pkg/alerter" "github.com/cloudskiff/driftctl/pkg/analyser" "github.com/cloudskiff/driftctl/pkg/filter" + "github.com/cloudskiff/driftctl/pkg/output" "github.com/cloudskiff/driftctl/pkg/resource" "github.com/cloudskiff/driftctl/pkg/resource/aws" "github.com/cloudskiff/driftctl/pkg/resource/github" @@ -72,11 +73,20 @@ func runTest(t *testing.T, cases TestCases) { c.mocks(resourceFactory) } - driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, testAlerter, resourceFactory, c.options, repo) + scanProgress := &output.MockProgress{} + scanProgress.On("Start").Return().Once() + scanProgress.On("Stop").Return().Once() + + iacProgress := &output.MockProgress{} + iacProgress.On("Start").Return().Once() + iacProgress.On("Stop").Return().Once() + + driftctl := pkg.NewDriftCTL(remoteSupplier, stateSupplier, testAlerter, resourceFactory, c.options, scanProgress, iacProgress, repo) analysis, err := driftctl.Run() c.assert(test.NewScanResult(t, analysis), err) + scanProgress.AssertExpectations(t) }) } } diff --git a/pkg/iac/supplier/supplier.go b/pkg/iac/supplier/supplier.go index 8166c8f1..76ed7b5e 100644 --- a/pkg/iac/supplier/supplier.go +++ b/pkg/iac/supplier/supplier.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/cloudskiff/driftctl/pkg/iac/terraform/state/backend" + "github.com/cloudskiff/driftctl/pkg/output" "github.com/cloudskiff/driftctl/pkg/terraform" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -28,7 +29,7 @@ func IsSupplierSupported(supplierKey string) bool { return false } -func GetIACSupplier(configs []config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, resourceSchemaRepository resource.SchemaRepositoryInterface) (resource.Supplier, error) { +func GetIACSupplier(configs []config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, progress output.Progress, resourceSchemaRepository resource.SchemaRepositoryInterface) (resource.Supplier, error) { chainSupplier := resource.NewChainSupplier() for _, config := range configs { if !IsSupplierSupported(config.Key) { @@ -39,7 +40,7 @@ func GetIACSupplier(configs []config.SupplierConfig, library *terraform.Provider var err error switch config.Key { case state.TerraformStateReaderSupplier: - supplier, err = state.NewReader(config, library, backendOpts, resourceSchemaRepository) + supplier, err = state.NewReader(config, library, backendOpts, progress, resourceSchemaRepository) default: return nil, errors.Errorf("Unsupported supplier '%s'", config.Key) } diff --git a/pkg/iac/supplier/supplier_test.go b/pkg/iac/supplier/supplier_test.go index 757be270..d7673a29 100644 --- a/pkg/iac/supplier/supplier_test.go +++ b/pkg/iac/supplier/supplier_test.go @@ -7,6 +7,7 @@ import ( "github.com/cloudskiff/driftctl/pkg/iac/config" "github.com/cloudskiff/driftctl/pkg/iac/terraform/state/backend" + "github.com/cloudskiff/driftctl/pkg/output" "github.com/cloudskiff/driftctl/pkg/terraform" "github.com/cloudskiff/driftctl/test/resource" ) @@ -83,8 +84,12 @@ func TestGetIACSupplier(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + progress := &output.MockProgress{} + progress.On("Start").Return().Times(1) + repo := resource.InitFakeSchemaRepository("aws", "3.19.0") - _, err := GetIACSupplier(tt.args.config, terraform.NewProviderLibrary(), tt.args.options, repo) + + _, err := GetIACSupplier(tt.args.config, terraform.NewProviderLibrary(), tt.args.options, progress, repo) if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { t.Errorf("GetIACSupplier() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/iac/terraform/state/terraform_state_reader.go b/pkg/iac/terraform/state/terraform_state_reader.go index 3537e83e..ff1af569 100644 --- a/pkg/iac/terraform/state/terraform_state_reader.go +++ b/pkg/iac/terraform/state/terraform_state_reader.go @@ -3,6 +3,7 @@ package state import ( "fmt" + "github.com/cloudskiff/driftctl/pkg/output" "github.com/fatih/color" "github.com/hashicorp/terraform/addrs" "github.com/hashicorp/terraform/states" @@ -30,6 +31,7 @@ type TerraformStateReader struct { enumerator enumerator.StateEnumerator deserializers []deserializer.CTYDeserializer backendOptions *backend.Options + progress output.Progress resourceSchemaRepository resource.SchemaRepositoryInterface } @@ -38,8 +40,8 @@ func (r *TerraformStateReader) initReader() error { return nil } -func NewReader(config config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, resourceSchemaRepository resource.SchemaRepositoryInterface) (*TerraformStateReader, error) { - reader := TerraformStateReader{library: library, config: config, deserializers: iac.Deserializers(), backendOptions: backendOpts, resourceSchemaRepository: resourceSchemaRepository} +func NewReader(config config.SupplierConfig, library *terraform.ProviderLibrary, backendOpts *backend.Options, progress output.Progress, resourceSchemaRepository resource.SchemaRepositoryInterface) (*TerraformStateReader, error) { + reader := TerraformStateReader{library: library, config: config, deserializers: iac.Deserializers(), backendOptions: backendOpts, progress: progress, resourceSchemaRepository: resourceSchemaRepository} err := reader.initReader() if err != nil { return nil, err @@ -227,6 +229,7 @@ func (r *TerraformStateReader) retrieveForState(path string) ([]resource.Resourc "path": r.config.Path, "backend": r.config.Backend, }).Debug("Reading resources from state") + r.progress.Inc() values, err := r.retrieve() if err != nil { return nil, err diff --git a/pkg/iac/terraform/state/terraform_state_reader_test.go b/pkg/iac/terraform/state/terraform_state_reader_test.go index a78905ac..29ed59f9 100644 --- a/pkg/iac/terraform/state/terraform_state_reader_test.go +++ b/pkg/iac/terraform/state/terraform_state_reader_test.go @@ -99,14 +99,16 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + progress := &output.MockProgress{} + progress.On("Inc").Return().Times(1) + progress.On("Stop").Return().Times(1) + shouldUpdate := tt.dirName == *goldenfile.Update var realProvider *aws.AWSTerraformProvider if shouldUpdate { var err error - progress := &output.MockProgress{} - progress.On("Inc").Return() realProvider, err = aws.NewAWSTerraformProvider(progress) if err != nil { t.Fatal(err) @@ -130,6 +132,7 @@ func TestTerraformStateReader_AWS_Resources(t *testing.T) { }, library: library, deserializers: iac.Deserializers(), + progress: progress, resourceSchemaRepository: repo, } @@ -180,14 +183,16 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + progress := &output.MockProgress{} + progress.On("Inc").Return().Times(1) + progress.On("Stop").Return().Times(1) + shouldUpdate := tt.dirName == *goldenfile.Update var realProvider *github.GithubTerraformProvider if shouldUpdate { var err error - progress := &output.MockProgress{} - progress.On("Inc").Return() realProvider, err = github.NewGithubTerraformProvider(progress) if err != nil { t.Fatal(err) @@ -211,6 +216,7 @@ func TestTerraformStateReader_Github_Resources(t *testing.T) { }, library: library, deserializers: iac.Deserializers(), + progress: progress, resourceSchemaRepository: repo, } diff --git a/pkg/output/progress.go b/pkg/output/progress.go index 7019f2b9..24b99894 100644 --- a/pkg/output/progress.go +++ b/pkg/output/progress.go @@ -1,6 +1,7 @@ package output import ( + "fmt" "time" "go.uber.org/atomic" @@ -22,17 +23,31 @@ type Progress interface { Val() uint64 } -type progress struct { - endChan chan struct{} - started *atomic.Bool - count *atomic.Uint64 +type ProgressOptions struct { + LoadingText string + FinishedText string + ShowCount bool } -func NewProgress() *progress { +type progress struct { + endChan chan struct{} + started *atomic.Bool + count *atomic.Uint64 + loadingText string + finishedText string + showCount bool + highestLineLength int +} + +func NewProgress(loadingText, finishedText string, showCount bool) *progress { return &progress{ nil, atomic.NewBool(false), atomic.NewUint64(0), + loadingText, + finishedText, + showCount, + 0, } } @@ -47,7 +62,11 @@ func (p *progress) Start() { func (p *progress) Stop() { if p.started.Swap(false) { - Printf("Scanned resources: (%d)\n", p.count.Load()) + if p.showCount { + p.printf("%s (%d)\n", p.finishedText, p.count.Load()) + } else { + p.printf("%s\r", p.finishedText) + } close(p.endChan) } } @@ -67,7 +86,7 @@ func (p *progress) Val() uint64 { func (p *progress) render() { i := -1 - Printf("Scanning resources:\r") + p.printf("%s\r", p.loadingText) for { select { case <-p.endChan: @@ -77,7 +96,11 @@ func (p *progress) render() { if i >= len(spinner) { i = 0 } - Printf("Scanning resources: %s (%d)\r", spinner[i], p.count.Load()) + if p.showCount { + p.printf("%s %s (%d)\r", p.loadingText, spinner[i], p.count.Load()) + } else { + p.printf("%s %s\r", p.loadingText, spinner[i]) + } } } } @@ -101,3 +124,20 @@ Loop: } logrus.Debug("Progress did not receive any tic. Stopping...") } + +func (p *progress) flush() { + for i := 0; i < p.highestLineLength; i++ { + Printf(" ") + } + Printf("\r") +} + +func (p *progress) printf(format string, args ...interface{}) { + txt := fmt.Sprintf(format, args...) + length := len(txt) + if length > p.highestLineLength { + p.highestLineLength = length + } + p.flush() + Printf(txt) +} diff --git a/pkg/output/progress_test.go b/pkg/output/progress_test.go index 262805d2..18d41e82 100644 --- a/pkg/output/progress_test.go +++ b/pkg/output/progress_test.go @@ -8,7 +8,7 @@ import ( ) func TestProgressTimeoutDoesNotInc(t *testing.T) { - progress := NewProgress() + progress := NewProgress("loading", "loaded", false) progress.Start() progress.Inc() progress.Stop() // should not hang @@ -21,7 +21,7 @@ func TestProgressTimeoutDoesNotInc(t *testing.T) { } func TestProgressTimeoutDoesNotHang(t *testing.T) { - progress := NewProgress() + progress := NewProgress("loading", "loaded", false) progress.Start() time.Sleep(progressTimeout) for progress.started.Load() == true { @@ -32,7 +32,7 @@ func TestProgressTimeoutDoesNotHang(t *testing.T) { } func TestProgress(t *testing.T) { - progress := NewProgress() + progress := NewProgress("loading", "loaded", false) progress.Start() progress.Inc() progress.Inc()