refactor: simplify memstore testing

main
sundowndev 2021-07-12 15:34:06 +02:00
parent b0b9d13b38
commit aebdbc688a
5 changed files with 80 additions and 195 deletions

View File

@ -240,7 +240,7 @@ func scanRun(opts *pkg.ScanOptions) error {
globaloutput.Printf(color.WhiteString("Provider version used to scan: %s. Use --tf-provider-version to use another version.\n"), resourceSchemaRepository.ProviderVersion.String()) globaloutput.Printf(color.WhiteString("Provider version used to scan: %s. Use --tf-provider-version to use another version.\n"), resourceSchemaRepository.ProviderVersion.String())
if !opts.DisableTelemetry { if !opts.DisableTelemetry {
telemetry.SendTelemetry(store) telemetry.SendTelemetry(store.Bucket(memstore.TelemetryBucket))
} }
if !analysis.IsSync() { if !analysis.IsSync() {

View File

@ -2,12 +2,10 @@ package pkg
import ( import (
"fmt" "fmt"
"runtime"
"time" "time"
"github.com/cloudskiff/driftctl/pkg/memstore" "github.com/cloudskiff/driftctl/pkg/memstore"
globaloutput "github.com/cloudskiff/driftctl/pkg/output" globaloutput "github.com/cloudskiff/driftctl/pkg/output"
"github.com/cloudskiff/driftctl/pkg/version"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -144,9 +142,6 @@ func (d DriftCTL) Run() (*analyser.Analysis, error) {
analysis.Duration = time.Since(start) analysis.Duration = time.Since(start)
analysis.Date = time.Now() analysis.Date = time.Now()
d.store.Bucket(memstore.TelemetryBucket).Set("version", version.Current())
d.store.Bucket(memstore.TelemetryBucket).Set("os", runtime.GOOS)
d.store.Bucket(memstore.TelemetryBucket).Set("arch", runtime.GOARCH)
d.store.Bucket(memstore.TelemetryBucket).Set("total_resources", analysis.Summary().TotalResources) d.store.Bucket(memstore.TelemetryBucket).Set("total_resources", analysis.Summary().TotalResources)
d.store.Bucket(memstore.TelemetryBucket).Set("total_managed", analysis.Summary().TotalManaged) d.store.Bucket(memstore.TelemetryBucket).Set("total_managed", analysis.Summary().TotalManaged)
d.store.Bucket(memstore.TelemetryBucket).Set("duration", uint(analysis.Duration.Seconds()+0.5)) d.store.Bucket(memstore.TelemetryBucket).Set("duration", uint(analysis.Duration.Seconds()+0.5))

View File

@ -32,7 +32,8 @@ type TestCase struct {
stateResources []resource.Resource stateResources []resource.Resource
remoteResources []resource.Resource remoteResources []resource.Resource
mocks func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface) mocks func(factory resource.ResourceFactory, repo resource.SchemaRepositoryInterface)
assert func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) assert func(t *testing.T, result *test.ScanResult, err error)
assertStore func(*testing.T, memstore.Store)
options *pkg.ScanOptions options *pkg.ScanOptions
} }
@ -105,7 +106,10 @@ func runTest(t *testing.T, cases TestCases) {
analysis, err := driftctl.Run() analysis, err := driftctl.Run()
c.assert(t, test.NewScanResult(t, analysis), err, store) c.assert(t, test.NewScanResult(t, analysis), err)
if c.assertStore != nil {
c.assertStore(t, store)
}
scanProgress.AssertExpectations(t) scanProgress.AssertExpectations(t)
}) })
} }
@ -127,12 +131,10 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
name: "analysis duration is set", name: "analysis duration is set",
stateResources: []resource.Resource{}, stateResources: []resource.Resource{},
remoteResources: []resource.Resource{}, remoteResources: []resource.Resource{},
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.NotZero(result.Duration) result.NotZero(result.Duration)
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -146,12 +148,10 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
remoteResources: []resource.Resource{ remoteResources: []resource.Resource{
&testresource.FakeResource{}, &testresource.FakeResource{},
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertInfrastructureIsInSync() result.AssertInfrastructureIsInSync()
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -166,12 +166,10 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
&testresource.FakeResource{}, &testresource.FakeResource{},
}, },
remoteResources: []resource.Resource{}, remoteResources: []resource.Resource{},
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertDeletedCount(1) result.AssertDeletedCount(1)
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -183,12 +181,10 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
remoteResources: []resource.Resource{ remoteResources: []resource.Resource{
&testresource.FakeResource{}, &testresource.FakeResource{},
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertUnmanagedCount(1) result.AssertUnmanagedCount(1)
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -212,7 +208,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("fake", "FakeResource", analyser.Change{ result.AssertResourceHasDrift("fake", "FakeResource", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -223,10 +219,8 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
Computed: false, Computed: false,
}) })
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -252,7 +246,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("fake", aws.AwsAmiResourceType, analyser.Change{ result.AssertResourceHasDrift("fake", aws.AwsAmiResourceType, analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -263,10 +257,8 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
Computed: true, Computed: true,
}) })
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -292,7 +284,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("fake", "FakeResource", analyser.Change{ result.AssertResourceHasDrift("fake", "FakeResource", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -303,10 +295,8 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
Computed: false, Computed: false,
}) })
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -332,7 +322,7 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("fake", "FakeResource", analyser.Change{ result.AssertResourceHasDrift("fake", "FakeResource", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -343,10 +333,8 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
Computed: false, Computed: false,
}) })
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -428,15 +416,13 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(2) result.AssertManagedCount(2)
result.AssertUnmanagedCount(2) result.AssertUnmanagedCount(2)
result.AssertDeletedCount(0) result.AssertDeletedCount(0)
result.AssertDriftCountTotal(0) result.AssertDriftCountTotal(0)
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 4, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 4, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -524,15 +510,13 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(2) result.AssertManagedCount(2)
result.AssertUnmanagedCount(4) result.AssertUnmanagedCount(4)
result.AssertDeletedCount(0) result.AssertDeletedCount(0)
result.AssertDriftCountTotal(0) result.AssertDriftCountTotal(0)
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 6, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 6, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -620,17 +604,15 @@ func TestDriftctlRun_BasicBehavior(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertCoverage(0) result.AssertCoverage(0)
result.AssertInfrastructureIsNotSync() result.AssertInfrastructureIsNotSync()
result.AssertManagedCount(0) result.AssertManagedCount(0)
result.AssertUnmanagedCount(1) result.AssertUnmanagedCount(1)
result.AssertDeletedCount(0) result.AssertDeletedCount(0)
result.AssertDriftCountTotal(0) result.AssertDriftCountTotal(0)
},
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version")) assertStore: func(t *testing.T, store memstore.Store) {
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources")) assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed")) assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration")) assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
@ -671,16 +653,9 @@ func TestDriftctlRun_BasicFilter(t *testing.T) {
Attrs: &resource.Attributes{}, Attrs: &resource.Attributes{},
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertUnmanagedCount(1) result.AssertUnmanagedCount(1)
result.AssertResourceUnmanaged("res2", "filtered") result.AssertResourceUnmanaged("res2", "filtered")
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='filtered'" filterStr := "Type=='filtered'"
@ -707,16 +682,9 @@ func TestDriftctlRun_BasicFilter(t *testing.T) {
Attrs: &resource.Attributes{}, Attrs: &resource.Attributes{},
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertUnmanagedCount(1) result.AssertUnmanagedCount(1)
result.AssertResourceUnmanaged("res2", "filtered") result.AssertResourceUnmanaged("res2", "filtered")
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Id=='res2'" filterStr := "Id=='res2'"
@ -745,16 +713,9 @@ func TestDriftctlRun_BasicFilter(t *testing.T) {
Attrs: &resource.Attributes{}, Attrs: &resource.Attributes{},
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertUnmanagedCount(1) result.AssertUnmanagedCount(1)
result.AssertResourceUnmanaged("res1", "filtered") result.AssertResourceUnmanaged("res1", "filtered")
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 0, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Attr.test_field=='value to filter on'" filterStr := "Attr.test_field=='value to filter on'"
@ -817,7 +778,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Sch: getSchema(repo, aws.AwsS3BucketPolicyResourceType), Sch: getSchema(repo, aws.AwsS3BucketPolicyResourceType),
}) })
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("foo", "aws_s3_bucket_policy", analyser.Change{ result.AssertResourceHasDrift("foo", "aws_s3_bucket_policy", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -829,13 +790,6 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Computed: false, Computed: false,
JsonString: true, JsonString: true,
}) })
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_s3_bucket_policy' && Attr.bucket=='foo'" filterStr := "Type=='aws_s3_bucket_policy' && Attr.bucket=='foo'"
@ -929,7 +883,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
}) })
})).Times(1).Return(&bar, nil) })).Times(1).Return(&bar, nil)
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(2) result.AssertManagedCount(2)
result.AssertResourceHasDrift("vol-02862d9b39045a3a4", "aws_ebs_volume", analyser.Change{ result.AssertResourceHasDrift("vol-02862d9b39045a3a4", "aws_ebs_volume", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -949,13 +903,6 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
}, },
Computed: true, Computed: true,
}) })
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_ebs_volume' && Attr.availability_zone=='us-east-1'" filterStr := "Type=='aws_ebs_volume' && Attr.availability_zone=='us-east-1'"
@ -1052,16 +999,9 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
}, },
}, nil) }, nil)
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(2) result.AssertManagedCount(2)
result.AssertInfrastructureIsInSync() result.AssertInfrastructureIsInSync()
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_route' && Attr.gateway_id=='igw-07b7844a8fd17a638'" filterStr := "Type=='aws_route' && Attr.gateway_id=='igw-07b7844a8fd17a638'"
@ -1113,7 +1053,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Sch: getSchema(repo, aws.AwsSnsTopicPolicyResourceType), Sch: getSchema(repo, aws.AwsSnsTopicPolicyResourceType),
}, nil) }, nil)
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("foo", "aws_sns_topic_policy", analyser.Change{ result.AssertResourceHasDrift("foo", "aws_sns_topic_policy", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -1125,13 +1065,6 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Computed: false, Computed: false,
JsonString: true, JsonString: true,
}) })
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_sns_topic_policy' && Attr.arn=='arn'" filterStr := "Type=='aws_sns_topic_policy' && Attr.arn=='arn'"
@ -1182,7 +1115,7 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Sch: getSchema(repo, aws.AwsSqsQueuePolicyResourceType), Sch: getSchema(repo, aws.AwsSqsQueuePolicyResourceType),
}, nil) }, nil)
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(1) result.AssertManagedCount(1)
result.AssertResourceHasDrift("foo", "aws_sqs_queue_policy", analyser.Change{ result.AssertResourceHasDrift("foo", "aws_sqs_queue_policy", analyser.Change{
Change: diff.Change{ Change: diff.Change{
@ -1194,13 +1127,6 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Computed: false, Computed: false,
JsonString: true, JsonString: true,
}) })
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 1, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_sqs_queue_policy' && Attr.queue_url=='foo'" filterStr := "Type=='aws_sqs_queue_policy' && Attr.queue_url=='foo'"
@ -1514,16 +1440,9 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
}) })
})).Times(1).Return(&rule4, nil) })).Times(1).Return(&rule4, nil)
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(7) result.AssertManagedCount(7)
result.AssertInfrastructureIsInSync() result.AssertInfrastructureIsInSync()
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 7, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 7, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_security_group_rule' && Attr.security_group_id=='sg-0254c038e32f25530'" filterStr := "Type=='aws_security_group_rule' && Attr.security_group_id=='sg-0254c038e32f25530'"
@ -1659,16 +1578,9 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
}, },
}, nil) }, nil)
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertManagedCount(2) result.AssertManagedCount(2)
result.AssertInfrastructureIsInSync() result.AssertInfrastructureIsInSync()
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 2, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
options: func(t *testing.T) *pkg.ScanOptions { options: func(t *testing.T) *pkg.ScanOptions {
filterStr := "Type=='aws_iam_policy_attachment'" filterStr := "Type=='aws_iam_policy_attachment'"
@ -1750,16 +1662,9 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
Attrs: &resource.Attributes{}, Attrs: &resource.Attributes{},
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertInfrastructureIsInSync() result.AssertInfrastructureIsInSync()
result.AssertManagedCount(5) result.AssertManagedCount(5)
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 5, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 5, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
}, },
{ {
@ -1836,16 +1741,9 @@ func TestDriftctlRun_Middlewares(t *testing.T) {
}, },
}, },
}, },
assert: func(t *testing.T, result *test.ScanResult, err error, store memstore.Store) { assert: func(t *testing.T, result *test.ScanResult, err error) {
result.AssertInfrastructureIsInSync() result.AssertInfrastructureIsInSync()
result.AssertManagedCount(5) result.AssertManagedCount(5)
assert.Equal(t, "dev-dev", store.Bucket(memstore.TelemetryBucket).Get("version"))
assert.Equal(t, "linux", store.Bucket(memstore.TelemetryBucket).Get("os"))
assert.Equal(t, "amd64", store.Bucket(memstore.TelemetryBucket).Get("arch"))
assert.Equal(t, 5, store.Bucket(memstore.TelemetryBucket).Get("total_resources"))
assert.Equal(t, 5, store.Bucket(memstore.TelemetryBucket).Get("total_managed"))
assert.Equal(t, uint(0), store.Bucket(memstore.TelemetryBucket).Get("duration"))
}, },
}, },
} }

View File

@ -4,8 +4,10 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net/http" "net/http"
"runtime"
"github.com/cloudskiff/driftctl/pkg/memstore" "github.com/cloudskiff/driftctl/pkg/memstore"
"github.com/cloudskiff/driftctl/pkg/version"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -18,30 +20,22 @@ type telemetry struct {
Duration uint `json:"duration"` Duration uint `json:"duration"`
} }
func SendTelemetry(s memstore.Store) { func SendTelemetry(store memstore.Bucket) {
t := &telemetry{} t := &telemetry{
Version: version.Current(),
if val, ok := s.Bucket(memstore.TelemetryBucket).Get("version").(string); ok { Os: runtime.GOOS,
t.Version = val Arch: runtime.GOARCH,
} }
if val, ok := s.Bucket(memstore.TelemetryBucket).Get("os").(string); ok { if val, ok := store.Get("total_resources").(int); ok {
t.Os = val
}
if val, ok := s.Bucket(memstore.TelemetryBucket).Get("arch").(string); ok {
t.Arch = val
}
if val, ok := s.Bucket(memstore.TelemetryBucket).Get("total_resources").(int); ok {
t.TotalResources = val t.TotalResources = val
} }
if val, ok := s.Bucket(memstore.TelemetryBucket).Get("total_managed").(int); ok { if val, ok := store.Get("total_managed").(int); ok {
t.TotalManaged = val t.TotalManaged = val
} }
if val, ok := s.Bucket(memstore.TelemetryBucket).Get("duration").(uint); ok { if val, ok := store.Get("duration").(uint); ok {
t.Duration = val t.Duration = val
} }

View File

@ -24,7 +24,7 @@ func TestSendTelemetry(t *testing.T) {
analysis *analyser.Analysis analysis *analyser.Analysis
expectedBody *telemetry expectedBody *telemetry
response *http.Response response *http.Response
setStoreValues func(memstore.Store, *analyser.Analysis) setStoreValues func(memstore.Bucket, *analyser.Analysis)
}{ }{
{ {
name: "valid analysis", name: "valid analysis",
@ -43,13 +43,10 @@ func TestSendTelemetry(t *testing.T) {
TotalManaged: 1, TotalManaged: 1,
Duration: 123, Duration: 123,
}, },
setStoreValues: func(s memstore.Store, a *analyser.Analysis) { setStoreValues: func(s memstore.Bucket, a *analyser.Analysis) {
s.Bucket(memstore.TelemetryBucket).Set("version", version.Current()) s.Set("total_resources", a.Summary().TotalResources)
s.Bucket(memstore.TelemetryBucket).Set("os", runtime.GOOS) s.Set("total_managed", a.Summary().TotalManaged)
s.Bucket(memstore.TelemetryBucket).Set("arch", runtime.GOARCH) s.Set("duration", uint(a.Duration.Seconds()+0.5))
s.Bucket(memstore.TelemetryBucket).Set("total_resources", a.Summary().TotalResources)
s.Bucket(memstore.TelemetryBucket).Set("total_managed", a.Summary().TotalManaged)
s.Bucket(memstore.TelemetryBucket).Set("duration", uint(a.Duration.Seconds()+0.5))
}, },
}, },
{ {
@ -65,13 +62,10 @@ func TestSendTelemetry(t *testing.T) {
Arch: runtime.GOARCH, Arch: runtime.GOARCH,
Duration: 124, Duration: 124,
}, },
setStoreValues: func(s memstore.Store, a *analyser.Analysis) { setStoreValues: func(s memstore.Bucket, a *analyser.Analysis) {
s.Bucket(memstore.TelemetryBucket).Set("version", version.Current()) s.Set("total_resources", a.Summary().TotalResources)
s.Bucket(memstore.TelemetryBucket).Set("os", runtime.GOOS) s.Set("total_managed", a.Summary().TotalManaged)
s.Bucket(memstore.TelemetryBucket).Set("arch", runtime.GOARCH) s.Set("duration", uint(a.Duration.Seconds()+0.5))
s.Bucket(memstore.TelemetryBucket).Set("total_resources", a.Summary().TotalResources)
s.Bucket(memstore.TelemetryBucket).Set("total_managed", a.Summary().TotalManaged)
s.Bucket(memstore.TelemetryBucket).Set("duration", uint(a.Duration.Seconds()+0.5))
}, },
}, },
{ {
@ -85,16 +79,20 @@ func TestSendTelemetry(t *testing.T) {
a.Duration = 123.5 * 1e9 // 123.5 seconds a.Duration = 123.5 * 1e9 // 123.5 seconds
return a return a
}(), }(),
expectedBody: &telemetry{}, expectedBody: &telemetry{
setStoreValues: func(s memstore.Store, a *analyser.Analysis) {}, Version: version.Current(),
Os: runtime.GOOS,
Arch: runtime.GOARCH,
},
setStoreValues: func(s memstore.Bucket, a *analyser.Analysis) {},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := memstore.New() store := memstore.New().Bucket(memstore.TelemetryBucket)
if tt.analysis != nil { if tt.analysis != nil {
tt.setStoreValues(s, tt.analysis) tt.setStoreValues(store, tt.analysis)
} }
httpmock.Reset() httpmock.Reset()
@ -124,7 +122,7 @@ func TestSendTelemetry(t *testing.T) {
}, },
) )
} }
SendTelemetry(s) SendTelemetry(store)
}) })
} }
} }