From a8d1393e961d5db72c45efa840a841854e96a2c1 Mon Sep 17 00:00:00 2001 From: Mzack9999 Date: Wed, 3 Apr 2024 17:50:57 +0200 Subject: [PATCH] init- using resizable components --- examples/advanced/advanced.go | 7 ++- go.mod | 2 + go.sum | 3 ++ internal/runner/inputs.go | 13 +++-- pkg/core/execute_options.go | 7 ++- pkg/core/executors.go | 14 ++--- pkg/core/workflow_execute.go | 9 ++-- pkg/core/workpool.go | 26 ++++------ pkg/js/compiler/non-pool.go | 6 +-- pkg/js/compiler/pool.go | 6 +-- .../common/automaticscan/automaticscan.go | 9 ++-- pkg/protocols/dns/request.go | 7 ++- pkg/protocols/file/request.go | 9 ++-- pkg/protocols/headless/engine/page_actions.go | 28 +++++----- pkg/protocols/http/httputils/spm.go | 9 ++-- pkg/protocols/javascript/js.go | 5 +- .../network/networkclientpool/clientpool.go | 2 +- pkg/protocols/network/request.go | 7 ++- pkg/protocols/offlinehttp/request.go | 11 ++-- pkg/tmplexec/flow/vm.go | 51 +++++++++---------- 20 files changed, 126 insertions(+), 105 deletions(-) diff --git a/examples/advanced/advanced.go b/examples/advanced/advanced.go index 5ce579b3..110160f9 100644 --- a/examples/advanced/advanced.go +++ b/examples/advanced/advanced.go @@ -2,7 +2,7 @@ package main import ( nuclei "github.com/projectdiscovery/nuclei/v3/lib" - "github.com/remeh/sizedwaitgroup" + syncutil "github.com/projectdiscovery/utils/sync" ) func main() { @@ -12,7 +12,10 @@ func main() { panic(err) } // setup sizedWaitgroup to handle concurrency - sg := sizedwaitgroup.New(10) + sg, err := syncutil.New(syncutil.WithSize(10)) + if err != nil { + panic(err) + } // scan 1 = run dns templates on scanme.sh sg.Add() diff --git a/go.mod b/go.mod index 7eeaf573..6a0e8bfc 100644 --- a/go.mod +++ b/go.mod @@ -142,6 +142,8 @@ require ( github.com/docker/cli v24.0.5+incompatible // indirect github.com/docker/docker v24.0.9+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect + github.com/eapache/channels v1.1.0 // indirect + github.com/eapache/queue v1.1.0 // indirect github.com/fatih/color v1.15.0 // indirect github.com/free5gc/util v1.0.5-0.20230511064842-2e120956883b // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect diff --git a/go.sum b/go.sum index cefd9ae4..83551eb1 100644 --- a/go.sum +++ b/go.sum @@ -296,8 +296,11 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj6 github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/eapache/channels v1.1.0 h1:F1taHcn7/F0i8DYqKXJnyhJcVpp2kgFcNePxXtnyu4k= +github.com/eapache/channels v1.1.0/go.mod h1:jMm2qB5Ubtg9zLd+inMZd2/NUvXgzmWXsDaLyQIGfH0= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= +github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= diff --git a/internal/runner/inputs.go b/internal/runner/inputs.go index 8dc27a7a..75a86991 100644 --- a/internal/runner/inputs.go +++ b/internal/runner/inputs.go @@ -12,7 +12,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" "github.com/projectdiscovery/nuclei/v3/pkg/utils" stringsutil "github.com/projectdiscovery/utils/strings" - "github.com/remeh/sizedwaitgroup" + syncutil "github.com/projectdiscovery/utils/sync" ) const probeBulkSize = 50 @@ -45,8 +45,11 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) { } // Probe the non-standard URLs and store them in cache - swg := sizedwaitgroup.New(bulkSize) - count := int32(0) + swg, err := syncutil.New(syncutil.WithSize(bulkSize)) + if err != nil { + return nil, errors.Wrap(err, "could not create adaptive group") + } + var count atomic.Int32 r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool { if stringsutil.HasPrefixAny(value.Input, "http://", "https://") { return true @@ -57,7 +60,7 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) { defer swg.Done() if result := utils.ProbeURL(input.Input, httpxClient); result != "" { - atomic.AddInt32(&count, 1) + count.Add(1) _ = hm.Set(input.Input, []byte(result)) } }(value) @@ -65,6 +68,6 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) { }) swg.Wait() - gologger.Info().Msgf("Found %d URL from httpx", atomic.LoadInt32(&count)) + gologger.Info().Msgf("Found %d URL from httpx", count.Load()) return hm, nil } diff --git a/pkg/core/execute_options.go b/pkg/core/execute_options.go index fd1fadae..580b8b0a 100644 --- a/pkg/core/execute_options.go +++ b/pkg/core/execute_options.go @@ -4,8 +4,6 @@ import ( "sync" "sync/atomic" - "github.com/remeh/sizedwaitgroup" - "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/input/provider" "github.com/projectdiscovery/nuclei/v3/pkg/output" @@ -14,6 +12,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy" stringsutil "github.com/projectdiscovery/utils/strings" + syncutil "github.com/projectdiscovery/utils/sync" ) // Execute takes a list of templates/workflows that have been compiled @@ -111,7 +110,7 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe for _, template := range templatesList { templateType := template.Type() - var wg *sizedwaitgroup.SizedWaitGroup + var wg *syncutil.AdaptiveWaitGroup if templateType == types.HeadlessProtocol { wg = wp.Headless } else { @@ -134,7 +133,7 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe // executeHostSpray executes scan using host spray strategy where templates are iterated over each target func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { results := &atomic.Bool{} - wp := sizedwaitgroup.New(e.options.BulkSize + e.options.HeadlessBulkSize) + wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize)) target.Iterate(func(value *contextargs.MetaInput) bool { wp.Add() diff --git a/pkg/core/executors.go b/pkg/core/executors.go index b491bd8e..ace7acb2 100644 --- a/pkg/core/executors.go +++ b/pkg/core/executors.go @@ -11,7 +11,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/templates" "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" generalTypes "github.com/projectdiscovery/nuclei/v3/pkg/types" - "github.com/remeh/sizedwaitgroup" + syncutil "github.com/projectdiscovery/utils/sync" ) // Executors are low level executors that deals with template execution on a target @@ -104,9 +104,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target return true } - wg.WaitGroup.Add() + wg.Add() go func(index uint32, skip bool, value *contextargs.MetaInput) { - defer wg.WaitGroup.Done() + defer wg.Done() defer cleanupInFlight(index) if skip { return @@ -140,7 +140,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target index++ return true }) - wg.WaitGroup.Wait() + wg.Wait() // on completion marks the template as completed currentInfo.Lock() @@ -158,14 +158,14 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta wp := e.GetWorkPool() for _, tpl := range alltemplates { - var sg *sizedwaitgroup.SizedWaitGroup + var sg *syncutil.AdaptiveWaitGroup if tpl.Type() == types.HeadlessProtocol { sg = wp.Headless } else { sg = wp.Default } sg.Add() - go func(template *templates.Template, value *contextargs.MetaInput, wg *sizedwaitgroup.SizedWaitGroup) { + go func(template *templates.Template, value *contextargs.MetaInput, wg *syncutil.AdaptiveWaitGroup) { defer wg.Done() var match bool @@ -213,7 +213,7 @@ func (e *ChildExecuter) Close() *atomic.Bool { func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs.MetaInput) { templateType := template.Type() - var wg *sizedwaitgroup.SizedWaitGroup + var wg *syncutil.AdaptiveWaitGroup if templateType == types.HeadlessProtocol { wg = e.e.workPool.Headless } else { diff --git a/pkg/core/workflow_execute.go b/pkg/core/workflow_execute.go index cb877cc6..19d6f0d6 100644 --- a/pkg/core/workflow_execute.go +++ b/pkg/core/workflow_execute.go @@ -5,13 +5,12 @@ import ( "net/http/cookiejar" "sync/atomic" - "github.com/remeh/sizedwaitgroup" - "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/output" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" "github.com/projectdiscovery/nuclei/v3/pkg/scan" "github.com/projectdiscovery/nuclei/v3/pkg/workflows" + syncutil "github.com/projectdiscovery/utils/sync" ) const workflowStepExecutionError = "[%s] Could not execute workflow step: %s\n" @@ -32,7 +31,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b if templateThreads == 1 { templateThreads++ } - swg := sizedwaitgroup.New(templateThreads) + swg, _ := syncutil.New(syncutil.WithSize(templateThreads)) for _, template := range w.Workflows { swg.Add() @@ -40,7 +39,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b func(template *workflows.WorkflowTemplate) { defer swg.Done() - if err := e.runWorkflowStep(template, ctx, results, &swg, w); err != nil { + if err := e.runWorkflowStep(template, ctx, results, swg, w); err != nil { gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err) } }(template) @@ -51,7 +50,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b // runWorkflowStep runs a workflow step for the workflow. It executes the workflow // in a recursive manner running all subtemplates and matchers. -func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan.ScanContext, results *atomic.Bool, swg *sizedwaitgroup.SizedWaitGroup, w *workflows.Workflow) error { +func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan.ScanContext, results *atomic.Bool, swg *syncutil.AdaptiveWaitGroup, w *workflows.Workflow) error { var firstMatched bool var err error var mainErr error diff --git a/pkg/core/workpool.go b/pkg/core/workpool.go index 07117595..cd17a0b3 100644 --- a/pkg/core/workpool.go +++ b/pkg/core/workpool.go @@ -1,9 +1,8 @@ package core import ( - "github.com/remeh/sizedwaitgroup" - "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" + syncutil "github.com/projectdiscovery/utils/sync" ) // WorkPool implements an execution pool for executing different @@ -12,8 +11,8 @@ import ( // It also allows Configuration of such requirements. This is used // for per-module like separate headless concurrency etc. type WorkPool struct { - Headless *sizedwaitgroup.SizedWaitGroup - Default *sizedwaitgroup.SizedWaitGroup + Headless *syncutil.AdaptiveWaitGroup + Default *syncutil.AdaptiveWaitGroup config WorkPoolConfig } @@ -31,13 +30,13 @@ type WorkPoolConfig struct { // NewWorkPool returns a new WorkPool instance func NewWorkPool(config WorkPoolConfig) *WorkPool { - headlessWg := sizedwaitgroup.New(config.HeadlessTypeConcurrency) - defaultWg := sizedwaitgroup.New(config.TypeConcurrency) + headlessWg, _ := syncutil.New(syncutil.WithSize(config.HeadlessTypeConcurrency)) + defaultWg, _ := syncutil.New(syncutil.WithSize(config.TypeConcurrency)) return &WorkPool{ config: config, - Headless: &headlessWg, - Default: &defaultWg, + Headless: headlessWg, + Default: defaultWg, } } @@ -47,19 +46,14 @@ func (w *WorkPool) Wait() { w.Headless.Wait() } -// InputWorkPool is a work pool per-input -type InputWorkPool struct { - WaitGroup *sizedwaitgroup.SizedWaitGroup -} - // InputPool returns a work pool for an input type -func (w *WorkPool) InputPool(templateType types.ProtocolType) *InputWorkPool { +func (w *WorkPool) InputPool(templateType types.ProtocolType) *syncutil.AdaptiveWaitGroup { var count int if templateType == types.HeadlessProtocol { count = w.config.HeadlessInputConcurrency } else { count = w.config.InputConcurrency } - swg := sizedwaitgroup.New(count) - return &InputWorkPool{WaitGroup: &swg} + swg, _ := syncutil.New(syncutil.WithSize(count)) + return swg } diff --git a/pkg/js/compiler/non-pool.go b/pkg/js/compiler/non-pool.go index 8057c496..218b89b8 100644 --- a/pkg/js/compiler/non-pool.go +++ b/pkg/js/compiler/non-pool.go @@ -4,13 +4,13 @@ import ( "sync" "github.com/dop251/goja" - "github.com/remeh/sizedwaitgroup" + syncutil "github.com/projectdiscovery/utils/sync" ) var ( - ephemeraljsc = sizedwaitgroup.New(NonPoolingVMConcurrency) + ephemeraljsc, _ = syncutil.New(syncutil.WithSize(NonPoolingVMConcurrency)) lazyFixedSgInit = sync.OnceFunc(func() { - ephemeraljsc = sizedwaitgroup.New(NonPoolingVMConcurrency) + ephemeraljsc, _ = syncutil.New(syncutil.WithSize(NonPoolingVMConcurrency)) }) ) diff --git a/pkg/js/compiler/pool.go b/pkg/js/compiler/pool.go index 6dba600f..fc0b6163 100644 --- a/pkg/js/compiler/pool.go +++ b/pkg/js/compiler/pool.go @@ -36,7 +36,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/js/libs/goconsole" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" stringsutil "github.com/projectdiscovery/utils/strings" - "github.com/remeh/sizedwaitgroup" + syncutil "github.com/projectdiscovery/utils/sync" ) const ( @@ -51,9 +51,9 @@ var ( // autoregister console node module with default printer it uses gologger backend require.RegisterNativeModule(console.ModuleName, console.RequireWithPrinter(goconsole.NewGoConsolePrinter())) }) - pooljsc sizedwaitgroup.SizedWaitGroup + pooljsc *syncutil.AdaptiveWaitGroup lazySgInit = sync.OnceFunc(func() { - pooljsc = sizedwaitgroup.New(PoolingJsVmConcurrency) + pooljsc, _ = syncutil.New(syncutil.WithSize(PoolingJsVmConcurrency)) }) ) diff --git a/pkg/protocols/common/automaticscan/automaticscan.go b/pkg/protocols/common/automaticscan/automaticscan.go index 81939471..7119377b 100644 --- a/pkg/protocols/common/automaticscan/automaticscan.go +++ b/pkg/protocols/common/automaticscan/automaticscan.go @@ -30,8 +30,8 @@ import ( mapsutil "github.com/projectdiscovery/utils/maps" sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" + syncutil "github.com/projectdiscovery/utils/sync" wappalyzer "github.com/projectdiscovery/wappalyzergo" - "github.com/remeh/sizedwaitgroup" "gopkg.in/yaml.v2" ) @@ -128,7 +128,10 @@ func (s *Service) Close() bool { func (s *Service) Execute() error { gologger.Info().Msgf("Executing Automatic scan on %d target[s]", s.target.Count()) // setup host concurrency - sg := sizedwaitgroup.New(s.opts.Options.BulkSize) + sg, err := syncutil.New(syncutil.WithSize(s.opts.Options.BulkSize)) + if err != nil { + return err + } s.target.Iterate(func(value *contextargs.MetaInput) bool { sg.Add() go func(input *contextargs.MetaInput) { @@ -246,7 +249,7 @@ func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ( // execute tech detection templates on target tags := map[string]struct{}{} m := &sync.Mutex{} - sg := sizedwaitgroup.New(s.opts.Options.TemplateThreads) + sg, _ := syncutil.New(syncutil.WithSize(s.opts.Options.TemplateThreads)) counter := atomic.Uint32{} for _, t := range s.techTemplates { diff --git a/pkg/protocols/dns/request.go b/pkg/protocols/dns/request.go index 280e8161..ba9dd666 100644 --- a/pkg/protocols/dns/request.go +++ b/pkg/protocols/dns/request.go @@ -9,7 +9,6 @@ import ( "github.com/miekg/dns" "github.com/pkg/errors" - "github.com/remeh/sizedwaitgroup" "go.uber.org/multierr" "golang.org/x/exp/maps" @@ -27,6 +26,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/retryabledns" iputil "github.com/projectdiscovery/utils/ip" + syncutil "github.com/projectdiscovery/utils/sync" ) var _ protocols.Request = &Request{} @@ -64,7 +64,10 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, if request.generator != nil { iterator := request.generator.NewIterator() - swg := sizedwaitgroup.New(request.Threads) + swg, err := syncutil.New(syncutil.WithSize(request.Threads)) + if err != nil { + return err + } var multiErr error m := &sync.Mutex{} diff --git a/pkg/protocols/file/request.go b/pkg/protocols/file/request.go index eb73544f..f13f08d1 100644 --- a/pkg/protocols/file/request.go +++ b/pkg/protocols/file/request.go @@ -11,7 +11,6 @@ import ( "github.com/docker/go-units" "github.com/mholt/archiver" "github.com/pkg/errors" - "github.com/remeh/sizedwaitgroup" "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/operators" @@ -24,6 +23,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/helpers/responsehighlighter" templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" sliceutil "github.com/projectdiscovery/utils/slice" + syncutil "github.com/projectdiscovery/utils/sync" ) var _ protocols.Request = &Request{} @@ -47,8 +47,11 @@ var errEmptyResult = errors.New("Empty result") // ExecuteWithResults executes the protocol requests and returns results instead of writing them. func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, previous output.InternalEvent, callback protocols.OutputEventCallback) error { - wg := sizedwaitgroup.New(request.options.Options.BulkSize) - err := request.getInputPaths(input.MetaInput.Input, func(filePath string) { + wg, err := syncutil.New(syncutil.WithSize(request.options.Options.BulkSize)) + if err != nil { + return err + } + err = request.getInputPaths(input.MetaInput.Input, func(filePath string) { wg.Add() func(filePath string) { defer wg.Done() diff --git a/pkg/protocols/headless/engine/page_actions.go b/pkg/protocols/headless/engine/page_actions.go index 7338db44..348cab0a 100644 --- a/pkg/protocols/headless/engine/page_actions.go +++ b/pkg/protocols/headless/engine/page_actions.go @@ -214,7 +214,7 @@ func geTimeParameter(p *Page, act *Action, parameterName string, defaultValue ti } // ActionAddHeader executes a AddHeader action. -func (p *Page) ActionAddHeader(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) ActionAddHeader(act *Action, out map[string]string) error { in := p.getActionArgWithDefaultValues(act, "part") args := make(map[string]string) @@ -225,7 +225,7 @@ func (p *Page) ActionAddHeader(act *Action, out map[string]string /*TODO review } // ActionSetHeader executes a SetHeader action. -func (p *Page) ActionSetHeader(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) ActionSetHeader(act *Action, out map[string]string) error { in := p.getActionArgWithDefaultValues(act, "part") args := make(map[string]string) @@ -236,7 +236,7 @@ func (p *Page) ActionSetHeader(act *Action, out map[string]string /*TODO review } // ActionDeleteHeader executes a DeleteHeader action. -func (p *Page) ActionDeleteHeader(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) ActionDeleteHeader(act *Action, out map[string]string) error { in := p.getActionArgWithDefaultValues(act, "part") args := make(map[string]string) @@ -343,7 +343,7 @@ func (p *Page) RunScript(action *Action, out map[string]string) error { } // ClickElement executes click actions for an element. -func (p *Page) ClickElement(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) ClickElement(act *Action, out map[string]string) error { element, err := p.pageElementBy(act.Data) if err != nil { return errors.Wrap(err, errCouldNotGetElement) @@ -358,12 +358,12 @@ func (p *Page) ClickElement(act *Action, out map[string]string /*TODO review unu } // KeyboardAction executes a keyboard action on the page. -func (p *Page) KeyboardAction(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) KeyboardAction(act *Action, out map[string]string) error { return p.page.Keyboard.Type([]input.Key(p.getActionArgWithDefaultValues(act, "keys"))...) } // RightClickElement executes right click actions for an element. -func (p *Page) RightClickElement(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) RightClickElement(act *Action, out map[string]string) error { element, err := p.pageElementBy(act.Data) if err != nil { return errors.Wrap(err, errCouldNotGetElement) @@ -441,7 +441,7 @@ func (p *Page) Screenshot(act *Action, out map[string]string) error { } // InputElement executes input element actions for an element. -func (p *Page) InputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) InputElement(act *Action, out map[string]string) error { value := p.getActionArgWithDefaultValues(act, "value") if value == "" { return errinvalidArguments @@ -460,7 +460,7 @@ func (p *Page) InputElement(act *Action, out map[string]string /*TODO review unu } // TimeInputElement executes time input on an element -func (p *Page) TimeInputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) TimeInputElement(act *Action, out map[string]string) error { value := p.getActionArgWithDefaultValues(act, "value") if value == "" { return errinvalidArguments @@ -483,7 +483,7 @@ func (p *Page) TimeInputElement(act *Action, out map[string]string /*TODO review } // SelectInputElement executes select input statement action on a element -func (p *Page) SelectInputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) SelectInputElement(act *Action, out map[string]string) error { value := p.getActionArgWithDefaultValues(act, "value") if value == "" { return errinvalidArguments @@ -508,7 +508,7 @@ func (p *Page) SelectInputElement(act *Action, out map[string]string /*TODO revi } // WaitLoad waits for the page to load -func (p *Page) WaitLoad(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) WaitLoad(act *Action, out map[string]string) error { p.page.Timeout(2 * time.Second).WaitNavigation(proto.PageLifecycleEventNameFirstMeaningfulPaint)() // Wait for the window.onload event and also wait for the network requests @@ -538,7 +538,7 @@ func (p *Page) GetResource(act *Action, out map[string]string) error { } // FilesInput acts with a file input element on page -func (p *Page) FilesInput(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) FilesInput(act *Action, out map[string]string) error { element, err := p.pageElementBy(act.Data) if err != nil { return errors.Wrap(err, errCouldNotGetElement) @@ -589,7 +589,7 @@ func (p *Page) ExtractElement(act *Action, out map[string]string) error { } // WaitEvent waits for an event to happen on the page. -func (p *Page) WaitEvent(act *Action, out map[string]string /*TODO review unused parameter*/) (func() error, error) { +func (p *Page) WaitEvent(act *Action, out map[string]string) (func() error, error) { event := p.getActionArgWithDefaultValues(act, "event") if event == "" { return nil, errors.New("event not recognized") @@ -661,14 +661,14 @@ func (p *Page) pageElementBy(data map[string]string) (*rod.Element, error) { } // DebugAction enables debug action on a page. -func (p *Page) DebugAction(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) DebugAction(act *Action, out map[string]string) error { p.instance.browser.engine.SlowMotion(5 * time.Second) p.instance.browser.engine.Trace(true) return nil } // SleepAction sleeps on the page for a specified duration -func (p *Page) SleepAction(act *Action, out map[string]string /*TODO review unused parameter*/) error { +func (p *Page) SleepAction(act *Action, out map[string]string) error { seconds := act.Data["duration"] if seconds == "" { seconds = "5" diff --git a/pkg/protocols/http/httputils/spm.go b/pkg/protocols/http/httputils/spm.go index ccaa9a85..bca6c2ee 100644 --- a/pkg/protocols/http/httputils/spm.go +++ b/pkg/protocols/http/httputils/spm.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/remeh/sizedwaitgroup" + syncutil "github.com/projectdiscovery/utils/sync" ) // WorkPoolType is the type of work pool to use @@ -26,7 +26,7 @@ type StopAtFirstMatchHandler[T any] struct { // work pool and its type poolType WorkPoolType - sgPool sizedwaitgroup.SizedWaitGroup + sgPool *syncutil.AdaptiveWaitGroup wgPool *sync.WaitGroup // internal / unexported @@ -40,10 +40,13 @@ type StopAtFirstMatchHandler[T any] struct { // NewBlockingSPMHandler creates a new stop at first match handler func NewBlockingSPMHandler[T any](ctx context.Context, size int, spm bool) *StopAtFirstMatchHandler[T] { ctx1, cancel := context.WithCancel(ctx) + + awg, _ := syncutil.New(syncutil.WithSize(size)) + s := &StopAtFirstMatchHandler[T]{ ResultChan: make(chan T, 1), poolType: Blocking, - sgPool: sizedwaitgroup.New(size), + sgPool: awg, internalWg: &sync.WaitGroup{}, ctx: ctx1, cancel: cancel, diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index fc324869..48a1be48 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -32,8 +32,8 @@ import ( templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/nuclei/v3/pkg/types" errorutil "github.com/projectdiscovery/utils/errors" + syncutil "github.com/projectdiscovery/utils/sync" urlutil "github.com/projectdiscovery/utils/url" - "github.com/remeh/sizedwaitgroup" ) // Request is a request for the javascript protocol @@ -404,7 +404,8 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo requestOptions := request.options gotmatches := &atomic.Bool{} - sg := sizedwaitgroup.New(threads) + sg, _ := syncutil.New(syncutil.WithSize(threads)) + if request.generator != nil { iterator := request.generator.NewIterator() for { diff --git a/pkg/protocols/network/networkclientpool/clientpool.go b/pkg/protocols/network/networkclientpool/clientpool.go index 1a933413..a67cee29 100644 --- a/pkg/protocols/network/networkclientpool/clientpool.go +++ b/pkg/protocols/network/networkclientpool/clientpool.go @@ -11,7 +11,7 @@ var ( ) // Init initializes the clientpool implementation -func Init(options *types.Options /*TODO review unused parameter*/) error { +func Init(options *types.Options) error { // Don't create clients if already created in the past. if normalClient != nil { return nil diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index 01991c04..ef0ff01c 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -12,7 +12,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/remeh/sizedwaitgroup" "go.uber.org/multierr" "golang.org/x/exp/maps" @@ -34,6 +33,7 @@ import ( errorutil "github.com/projectdiscovery/utils/errors" mapsutil "github.com/projectdiscovery/utils/maps" "github.com/projectdiscovery/utils/reader" + syncutil "github.com/projectdiscovery/utils/sync" ) var ( @@ -178,7 +178,10 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA iterator := request.generator.NewIterator() var multiErr error m := &sync.Mutex{} - swg := sizedwaitgroup.New(request.Threads) + swg, err := syncutil.New(syncutil.WithSize(request.Threads)) + if err != nil { + return err + } for { value, ok := iterator.Value() diff --git a/pkg/protocols/offlinehttp/request.go b/pkg/protocols/offlinehttp/request.go index 7c64859e..4a440c16 100644 --- a/pkg/protocols/offlinehttp/request.go +++ b/pkg/protocols/offlinehttp/request.go @@ -6,7 +6,6 @@ import ( "os" "github.com/pkg/errors" - "github.com/remeh/sizedwaitgroup" "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/output" @@ -17,6 +16,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/utils/conversion" + syncutil "github.com/projectdiscovery/utils/sync" ) var _ protocols.Request = &Request{} @@ -29,10 +29,13 @@ func (request *Request) Type() templateTypes.ProtocolType { } // ExecuteWithResults executes the protocol requests and returns results instead of writing them. -func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata /*TODO review unused parameter*/, previous output.InternalEvent, callback protocols.OutputEventCallback) error { - wg := sizedwaitgroup.New(request.options.Options.BulkSize) +func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, previous output.InternalEvent, callback protocols.OutputEventCallback) error { + wg, err := syncutil.New(syncutil.WithSize(request.options.Options.BulkSize)) + if err != nil { + return err + } - err := request.getInputPaths(input.MetaInput.Input, func(data string) { + err = request.getInputPaths(input.MetaInput.Input, func(data string) { wg.Add() go func(data string) { diff --git a/pkg/tmplexec/flow/vm.go b/pkg/tmplexec/flow/vm.go index 2e22bd8e..f1f7dbb8 100644 --- a/pkg/tmplexec/flow/vm.go +++ b/pkg/tmplexec/flow/vm.go @@ -1,6 +1,7 @@ package flow import ( + "context" "reflect" "sync" @@ -12,33 +13,10 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/vardump" "github.com/projectdiscovery/nuclei/v3/pkg/tmplexec/flow/builtin" "github.com/projectdiscovery/nuclei/v3/pkg/types" - "github.com/remeh/sizedwaitgroup" + "github.com/projectdiscovery/utils/sync/sizedpool" ) -type jsWaitGroup struct { - sync.Once - sg sizedwaitgroup.SizedWaitGroup -} - -var jsPool = &jsWaitGroup{} - -// GetJSRuntime returns a new JS runtime from pool -func GetJSRuntime(opts *types.Options) *goja.Runtime { - jsPool.Do(func() { - if opts.JsConcurrency < 100 { - opts.JsConcurrency = 100 - } - jsPool.sg = sizedwaitgroup.New(opts.JsConcurrency) - }) - jsPool.sg.Add() - return gojapool.Get().(*goja.Runtime) -} - -// PutJSRuntime returns a JS runtime to pool -func PutJSRuntime(runtime *goja.Runtime) { - defer jsPool.sg.Done() - gojapool.Put(runtime) -} +var jsOnce sync.Once // js runtime pool using sync.Pool var gojapool = &sync.Pool{ @@ -49,8 +27,29 @@ var gojapool = &sync.Pool{ }, } -func registerBuiltins(runtime *goja.Runtime) { +var sizedgojapool *sizedpool.SizedPool[*goja.Runtime] +// GetJSRuntime returns a new JS runtime from pool +func GetJSRuntime(opts *types.Options) *goja.Runtime { + jsOnce.Do(func() { + if opts.JsConcurrency < 100 { + opts.JsConcurrency = 100 + } + sizedgojapool, _ = sizedpool.New[*goja.Runtime]( + sizedpool.WithPool[*goja.Runtime](gojapool), + sizedpool.WithSize[*goja.Runtime](int64(opts.JsConcurrency)), + ) + }) + runtime, _ := sizedgojapool.Get(context.TODO()) + return runtime +} + +// PutJSRuntime returns a JS runtime to pool +func PutJSRuntime(runtime *goja.Runtime) { + sizedgojapool.Put(runtime) +} + +func registerBuiltins(runtime *goja.Runtime) { _ = gojs.RegisterFuncWithSignature(runtime, gojs.FuncOpts{ Name: "log", Description: "Logs a given object/message to stdout (only for debugging purposes)",