refactor: avoid reading unmanaged resources

main
sundowndev 2021-12-08 14:56:20 +01:00
parent 812408be01
commit 18eea35b4f
No known key found for this signature in database
GPG Key ID: 100CE2799D978462
11 changed files with 278 additions and 64 deletions

View File

@ -74,7 +74,14 @@ type GenDriftIgnoreOptions struct {
}
func NewAnalysis(options AnalyzerOptions) *Analysis {
return &Analysis{options: options}
return &Analysis{
unmanaged: []*resource.Resource{},
managed: []*resource.Resource{},
deleted: []*resource.Resource{},
differences: []Difference{},
alerts: alerter.Alerts{},
options: options,
}
}
func (a Analysis) MarshalJSON() ([]byte, error) {

View File

@ -51,37 +51,57 @@ func NewAnalyzer(alerter *alerter.Alerter, options AnalyzerOptions, filter filte
return &Analyzer{alerter, options, filter}
}
func (a Analyzer) Analyze(remoteResources, resourcesFromState []*resource.Resource) (Analysis, error) {
analysis := Analysis{options: a.options}
func (a Analyzer) CompareEnumeration(analysis *Analysis, remoteResources, resourcesFromState []*resource.Resource) *Analysis {
// Iterate on remote resources and filter ignored resources
filteredRemoteResource := make([]*resource.Resource, 0, len(remoteResources))
filteredRemoteResources := make([]*resource.Resource, 0, len(remoteResources))
for _, remoteRes := range remoteResources {
if a.filter.IsResourceIgnored(remoteRes) || a.alerter.IsResourceIgnored(remoteRes) {
continue
}
filteredRemoteResource = append(filteredRemoteResource, remoteRes)
filteredRemoteResources = append(filteredRemoteResources, remoteRes)
}
haveComputedDiff := false
for _, stateRes := range resourcesFromState {
i, remoteRes, found := findCorrespondingRes(filteredRemoteResource, stateRes)
if a.filter.IsResourceIgnored(stateRes) || a.alerter.IsResourceIgnored(stateRes) {
continue
}
i, _, found := resource.FindCorrespondingRes(filteredRemoteResources, stateRes)
if !found {
analysis.AddDeleted(stateRes)
continue
}
// Remove managed resources, so it will remain only unmanaged ones
filteredRemoteResource = removeResourceByIndex(i, filteredRemoteResource)
filteredRemoteResources = removeResourceByIndex(i, filteredRemoteResources)
analysis.AddManaged(stateRes)
}
// Stop there if we are not in deep mode, we do not want to compute diffs
if !a.options.Deep {
if a.hasUnmanagedSecurityGroupRules(filteredRemoteResources) {
a.alerter.SendAlert("", newUnmanagedSecurityGroupRulesAlert())
}
// Add remaining unmanaged resources
analysis.AddUnmanaged(filteredRemoteResources...)
return analysis
}
func (a Analyzer) CompleteAnalysis(analysis *Analysis, managedResources, resourcesFromState []*resource.Resource) *Analysis {
// Stop there if we are not in deep mode, we do not want to compute diffs
if !a.options.Deep {
a.setAlerts(analysis)
return analysis
}
haveComputedDiff := false
for _, remoteRes := range managedResources {
if a.filter.IsResourceIgnored(remoteRes) || a.alerter.IsResourceIgnored(remoteRes) {
continue
}
_, stateRes, found := resource.FindCorrespondingRes(resourcesFromState, remoteRes)
if !found {
continue
}
@ -121,33 +141,17 @@ func (a Analyzer) Analyze(remoteResources, resourcesFromState []*resource.Resour
}
}
if a.hasUnmanagedSecurityGroupRules(filteredRemoteResource) {
a.alerter.SendAlert("", newUnmanagedSecurityGroupRulesAlert())
}
if haveComputedDiff {
a.alerter.SendAlert("", NewComputedDiffAlert())
}
// Add remaining unmanaged resources
analysis.AddUnmanaged(filteredRemoteResource...)
a.setAlerts(analysis)
// Sort resources by Terraform Id
// The purpose is to have a predictable output
analysis.SortResources()
analysis.SetAlerts(a.alerter.Retrieve())
return analysis, nil
return analysis
}
func findCorrespondingRes(resources []*resource.Resource, res *resource.Resource) (int, *resource.Resource, bool) {
for i, r := range resources {
if res.Equal(r) {
return i, r, true
}
}
return -1, nil, false
func (a Analyzer) setAlerts(analysis *Analysis) {
analysis.SetAlerts(a.alerter.Retrieve())
}
func removeResourceByIndex(i int, resources []*resource.Resource) []*resource.Resource {

View File

@ -1026,18 +1026,21 @@ func TestAnalyze(t *testing.T) {
addSchemaToRes(drift.res, repo)
}
result, err := analyzer.Analyze(c.cloud, c.iac)
analysis := NewAnalysis(AnalyzerOptions{Deep: true})
analysis = analyzer.CompareEnumeration(analysis, c.cloud, c.iac)
analysis = analyzer.CompleteAnalysis(analysis, c.cloud, c.iac)
analysis.SortResources()
if err != nil {
t.Error(err)
return
}
if result.IsSync() == c.hasDrifted {
t.Errorf("Drifted state does not match, got %t expected %t", result.IsSync(), !c.hasDrifted)
if analysis.IsSync() == c.hasDrifted {
t.Errorf("Drifted state does not match, got %t expected %t", analysis.IsSync(), !c.hasDrifted)
}
managedChanges, err := differ.Diff(result.Managed(), c.expected.Managed())
managedChanges, err := differ.Diff(analysis.Managed(), c.expected.Managed())
if err != nil {
t.Fatalf("Unable to compare %+v", err)
}
@ -1047,7 +1050,7 @@ func TestAnalyze(t *testing.T) {
}
}
unmanagedChanges, err := differ.Diff(result.Unmanaged(), c.expected.Unmanaged())
unmanagedChanges, err := differ.Diff(analysis.Unmanaged(), c.expected.Unmanaged())
if err != nil {
t.Fatalf("Unable to compare %+v", err)
}
@ -1057,7 +1060,7 @@ func TestAnalyze(t *testing.T) {
}
}
deletedChanges, err := differ.Diff(result.Deleted(), c.expected.Deleted())
deletedChanges, err := differ.Diff(analysis.Deleted(), c.expected.Deleted())
if err != nil {
t.Fatalf("Unable to compare %+v", err)
}
@ -1067,7 +1070,7 @@ func TestAnalyze(t *testing.T) {
}
}
diffChanges, err := differ.Diff(result.Differences(), c.expected.Differences())
diffChanges, err := differ.Diff(analysis.Differences(), c.expected.Differences())
if err != nil {
t.Fatalf("Unable to compare %+v", err)
}
@ -1077,7 +1080,7 @@ func TestAnalyze(t *testing.T) {
}
}
summaryChanges, err := differ.Diff(c.expected.Summary(), result.Summary())
summaryChanges, err := differ.Diff(c.expected.Summary(), analysis.Summary())
if err != nil {
t.Fatalf("Unable to compare %+v", err)
}
@ -1087,7 +1090,7 @@ func TestAnalyze(t *testing.T) {
}
}
alertsChanges, err := differ.Diff(result.Alerts(), c.expected.Alerts())
alertsChanges, err := differ.Diff(analysis.Alerts(), c.expected.Alerts())
if err != nil {
t.Fatalf("Unable to compare %+v", err)
}

View File

@ -37,7 +37,7 @@ type ScanOptions struct {
}
type DriftCTL struct {
remoteSupplier resource.Supplier
remoteSupplier resource.RemoteSupplier
iacSupplier resource.Supplier
alerter alerter.AlerterInterface
analyzer *analyser.Analyzer
@ -49,7 +49,7 @@ type DriftCTL struct {
store memstore.Store
}
func NewDriftCTL(remoteSupplier resource.Supplier,
func NewDriftCTL(remoteSupplier resource.RemoteSupplier,
iacSupplier resource.Supplier,
alerter *alerter.Alerter,
analyzer *analyser.Analyzer,
@ -75,7 +75,7 @@ func NewDriftCTL(remoteSupplier resource.Supplier,
func (d DriftCTL) Run() (*analyser.Analysis, error) {
start := time.Now()
remoteResources, resourcesFromState, err := d.scan()
remoteResources, resourcesFromState, err := d.enumerateResources()
if err != nil {
return nil, err
}
@ -149,11 +149,23 @@ func (d DriftCTL) Run() (*analyser.Analysis, error) {
}
}
analysis, err := d.analyzer.Analyze(remoteResources, resourcesFromState)
analysis := analyser.NewAnalysis(analyser.AnalyzerOptions{Deep: d.opts.Deep})
analysis = d.analyzer.CompareEnumeration(analysis, remoteResources, resourcesFromState)
if err != nil {
return nil, err
}
managedResources, err := d.readResources(analysis.Managed())
if err != nil {
return nil, err
}
analysis = d.analyzer.CompleteAnalysis(analysis, managedResources, resourcesFromState)
// Sort resources by Terraform Id
// The purpose is to have a predictable output
analysis.SortResources()
analysis.Duration = time.Since(start)
analysis.Date = time.Now()
@ -161,7 +173,7 @@ func (d DriftCTL) Run() (*analyser.Analysis, error) {
d.store.Bucket(memstore.TelemetryBucket).Set("total_managed", analysis.Summary().TotalManaged)
d.store.Bucket(memstore.TelemetryBucket).Set("duration", uint(analysis.Duration.Seconds()+0.5))
return &analysis, nil
return analysis, nil
}
func (d DriftCTL) Stop() {
@ -179,7 +191,7 @@ func (d DriftCTL) Stop() {
}
}
func (d DriftCTL) scan() (remoteResources []*resource.Resource, resourcesFromState []*resource.Resource, err error) {
func (d DriftCTL) enumerateResources() (remoteResources []*resource.Resource, resourcesFromState []*resource.Resource, err error) {
logrus.Info("Start reading IaC")
d.iacProgress.Start()
resourcesFromState, err = d.iacSupplier.Resources()
@ -188,13 +200,20 @@ func (d DriftCTL) scan() (remoteResources []*resource.Resource, resourcesFromSta
return nil, nil, err
}
logrus.Info("Start scanning cloud provider")
logrus.Info("Start enumerating cloud provider resources")
d.scanProgress.Start()
defer d.scanProgress.Stop()
remoteResources, err = d.remoteSupplier.Resources()
remoteResources, err = d.remoteSupplier.EnumerateResources()
if err != nil {
return nil, nil, err
}
return remoteResources, resourcesFromState, err
}
func (d DriftCTL) readResources(managedResources []*resource.Resource) ([]*resource.Resource, error) {
logrus.WithField("count", len(managedResources)).Info("Fetching details of managed resources")
d.scanProgress.Start()
defer d.scanProgress.Stop()
return d.remoteSupplier.ReadResources(managedResources)
}

View File

@ -73,8 +73,10 @@ func runTest(t *testing.T, cases TestCases) {
schema, _ := repo.GetSchema(res.ResourceType())
res.Sch = schema
}
remoteSupplier := &resource.MockSupplier{}
remoteSupplier.On("Resources").Return(c.remoteResources, nil)
remoteSupplier := &resource.MockRemoteSupplier{}
remoteSupplier.On("EnumerateResources").Return(c.remoteResources, nil)
remoteSupplier.On("ReadResources", mock.IsType([]*resource.Resource{})).Return(c.remoteResources, nil)
var resourceFactory resource.ResourceFactory = terraform.NewTerraformResourceFactory(repo)
@ -88,8 +90,8 @@ func runTest(t *testing.T, cases TestCases) {
}
scanProgress := &output.MockProgress{}
scanProgress.On("Start").Return().Once()
scanProgress.On("Stop").Return().Once()
scanProgress.On("Start").Return().Times(2)
scanProgress.On("Stop").Return().Times(2)
iacProgress := &output.MockProgress{}
iacProgress.On("Start").Return().Once()

View File

@ -0,0 +1,36 @@
// Code generated by mockery v2.8.0. DO NOT EDIT.
package common
import (
resource "github.com/snyk/driftctl/pkg/resource"
mock "github.com/stretchr/testify/mock"
)
// MockDetailsFetcher is an autogenerated mock type for the DetailsFetcher type
type MockDetailsFetcher struct {
mock.Mock
}
// ReadDetails provides a mock function with given fields: _a0
func (_m *MockDetailsFetcher) ReadDetails(_a0 *resource.Resource) (*resource.Resource, error) {
ret := _m.Called(_a0)
var r0 *resource.Resource
if rf, ok := ret.Get(0).(func(*resource.Resource) *resource.Resource); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*resource.Resource)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*resource.Resource) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -58,7 +58,7 @@ loop:
return results, runner.Err()
}
func (s *Scanner) scan() ([]*resource.Resource, error) {
func (s *Scanner) EnumerateResources() ([]*resource.Resource, error) {
for _, enumerator := range s.remoteLibrary.Enumerators() {
if s.filter.IsTypeIgnored(enumerator.SupportedType()) {
logrus.WithFields(logrus.Fields{
@ -89,16 +89,11 @@ func (s *Scanner) scan() ([]*resource.Resource, error) {
})
}
enumerationResult, err := s.retrieveRunnerResults(s.enumeratorRunner)
if err != nil {
return nil, err
}
return s.retrieveRunnerResults(s.enumeratorRunner)
}
if !s.options.Deep {
return enumerationResult, nil
}
for _, res := range enumerationResult {
func (s *Scanner) ReadResources(managedResources []*resource.Resource) ([]*resource.Resource, error) {
for _, res := range managedResources {
res := res
s.detailsFetcherRunner.Run(func() (interface{}, error) {
fetcher := s.remoteLibrary.GetDetailsFetcher(resource.ResourceType(res.ResourceType()))
@ -121,7 +116,17 @@ func (s *Scanner) scan() ([]*resource.Resource, error) {
}
func (s *Scanner) Resources() ([]*resource.Resource, error) {
resources, err := s.scan()
resources, err := s.EnumerateResources()
if err != nil {
return nil, err
}
if !s.options.Deep {
return resources, nil
}
// Be aware that this call will read all resources, no matter they're managed or not
resources, err = s.ReadResources(resources)
if err != nil {
return nil, err
}

View File

@ -29,3 +29,47 @@ func TestScannerShouldIgnoreType(t *testing.T) {
assert.Nil(t, err)
fakeEnumerator.AssertExpectations(t)
}
func TestScannerShouldReadManagedOnly(t *testing.T) {
Resources := []*resource.Resource{
{
Id: "test-1",
Type: "FakeType",
Attrs: &resource.Attributes{},
},
{
Id: "test-2",
Type: "FakeType",
Attrs: &resource.Attributes{},
},
}
// Initialize mocks
fakeEnumerator := &common.MockEnumerator{}
fakeEnumerator.On("SupportedType").Return(resource.ResourceType("FakeType"))
fakeEnumerator.On("Enumerate").Return(Resources, nil)
fakeDetailsFetcher := &common.MockDetailsFetcher{}
fakeDetailsFetcher.On("ReadDetails", Resources[1]).Return(Resources[1], nil)
remoteLibrary := common.NewRemoteLibrary()
remoteLibrary.AddEnumerator(fakeEnumerator)
remoteLibrary.AddDetailsFetcher("FakeType", fakeDetailsFetcher)
testFilter := &filter.MockFilter{}
testFilter.On("IsTypeIgnored", resource.ResourceType("FakeType")).Return(false)
s := NewScanner(remoteLibrary, alerter.NewAlerter(), ScannerOptions{Deep: true}, testFilter)
remoteResources, err := s.EnumerateResources()
assert.Nil(t, err)
remoteResources, err = s.ReadResources(remoteResources[1:])
assert.Nil(t, err)
assert.Equal(t, Resources[1:], remoteResources)
fakeEnumerator.AssertExpectations(t)
fakeDetailsFetcher.AssertExpectations(t)
testFilter.AssertExpectations(t)
}

View File

@ -0,0 +1,79 @@
// Code generated by mockery v2.8.0. DO NOT EDIT.
package resource
import mock "github.com/stretchr/testify/mock"
// MockRemoteSupplier is an autogenerated mock type for the RemoteSupplier type
type MockRemoteSupplier struct {
mock.Mock
}
// EnumerateResources provides a mock function with given fields:
func (_m *MockRemoteSupplier) EnumerateResources() ([]*Resource, error) {
ret := _m.Called()
var r0 []*Resource
if rf, ok := ret.Get(0).(func() []*Resource); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Resource)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ReadResources provides a mock function with given fields: _a0
func (_m *MockRemoteSupplier) ReadResources(_a0 []*Resource) ([]*Resource, error) {
ret := _m.Called(_a0)
var r0 []*Resource
if rf, ok := ret.Get(0).(func([]*Resource) []*Resource); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Resource)
}
}
var r1 error
if rf, ok := ret.Get(1).(func([]*Resource) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Resources provides a mock function with given fields:
func (_m *MockRemoteSupplier) Resources() ([]*Resource, error) {
ret := _m.Called()
var r0 []*Resource
if rf, ok := ret.Get(0).(func() []*Resource); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*Resource)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}

View File

@ -318,3 +318,12 @@ func (a *Attributes) sanitize(path string, original, copy reflect.Value) bool {
}
return true
}
func FindCorrespondingRes(resources []*Resource, res *Resource) (int, *Resource, bool) {
for i, r := range resources {
if res.Equal(r) {
return i, r, true
}
}
return -1, nil, false
}

View File

@ -5,6 +5,12 @@ type Supplier interface {
Resources() ([]*Resource, error)
}
type RemoteSupplier interface {
Supplier
EnumerateResources() ([]*Resource, error)
ReadResources([]*Resource) ([]*Resource, error)
}
type StoppableSupplier interface {
Supplier
Stop()