making payload concurrency dynamic via direct int change

dev
mzack 2024-04-03 23:06:08 +02:00
parent a140a4194e
commit af7450737a
6 changed files with 48 additions and 12 deletions

View File

@ -62,6 +62,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
variablesMap := request.options.Variables.Evaluate(vars) variablesMap := request.options.Variables.Evaluate(vars)
vars = generators.MergeMaps(vars, variablesMap, request.options.Constants) vars = generators.MergeMaps(vars, variablesMap, request.options.Constants)
// if request threads matches global payload concurrency we follow it
shouldFollowGlobal := request.Threads == request.options.Options.PayloadConcurrency
if request.generator != nil { if request.generator != nil {
iterator := request.generator.NewIterator() iterator := request.generator.NewIterator()
swg, err := syncutil.New(syncutil.WithSize(request.Threads)) swg, err := syncutil.New(syncutil.WithSize(request.Threads))
@ -76,6 +79,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency)
}
value = generators.MergeMaps(vars, value) value = generators.MergeMaps(vars, value)
swg.Add() swg.Add()
go func(newVars map[string]interface{}) { go func(newVars map[string]interface{}) {

View File

@ -143,6 +143,16 @@ func (h *StopAtFirstMatchHandler[T]) Release() {
} }
} }
func (h *StopAtFirstMatchHandler[T]) Resize(size int) {
if h.sgPool.Size != size {
h.sgPool.Resize(size)
}
}
func (h *StopAtFirstMatchHandler[T]) Size() int {
return h.sgPool.Size
}
// Wait waits for all work to be done // Wait waits for all work to be done
func (h *StopAtFirstMatchHandler[T]) Wait() { func (h *StopAtFirstMatchHandler[T]) Wait() {
switch h.poolType { switch h.poolType {

View File

@ -165,6 +165,9 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
// Workers that keeps enqueuing new requests // Workers that keeps enqueuing new requests
maxWorkers := request.Threads maxWorkers := request.Threads
// if request threads matches global payload concurrency we follow it
shouldFollowGlobal := maxWorkers == request.options.Options.PayloadConcurrency
if protocolstate.IsLowOnMemory() { if protocolstate.IsLowOnMemory() {
maxWorkers = protocolstate.GuardThreadsOrDefault(request.Threads) maxWorkers = protocolstate.GuardThreadsOrDefault(request.Threads)
} }
@ -198,6 +201,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency {
spmHandler.Resize(request.options.Options.PayloadConcurrency)
}
ctx := request.newContext(input) ctx := request.newContext(input)
generatedHttpRequest, err := generator.Make(ctx, input, inputData, payloads, dynamicValues) generatedHttpRequest, err := generator.Make(ctx, input, inputData, payloads, dynamicValues)
if err != nil { if err != nil {

View File

@ -404,6 +404,9 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
requestOptions := request.options requestOptions := request.options
gotmatches := &atomic.Bool{} gotmatches := &atomic.Bool{}
// if request threads matches global payload concurrency we follow it
shouldFollowGlobal := threads == request.options.Options.PayloadConcurrency
sg, _ := syncutil.New(syncutil.WithSize(threads)) sg, _ := syncutil.New(syncutil.WithSize(threads))
if request.generator != nil { if request.generator != nil {
@ -413,6 +416,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency {
sg.Resize(request.options.Options.PayloadConcurrency)
}
sg.Add() sg.Add()
go func() { go func() {
defer sg.Done() defer sg.Done()

View File

@ -174,6 +174,9 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
return err return err
} }
// if request threads matches global payload concurrency we follow it
shouldFollowGlobal := request.Threads == request.options.Options.PayloadConcurrency
if request.generator != nil { if request.generator != nil {
iterator := request.generator.NewIterator() iterator := request.generator.NewIterator()
var multiErr error var multiErr error
@ -188,6 +191,12 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency)
}
value = generators.MergeMaps(value, payloads) value = generators.MergeMaps(value, payloads)
swg.Add() swg.Add()
go func(vars map[string]interface{}) { go func(vars map[string]interface{}) {

View File

@ -34,9 +34,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/types"
) )
// Optional Callback to update Thread count in payloads across all requests
type PayloadThreadSetterCallback func(opts *ExecutorOptions, totalRequests, currentThreads int) int
var ( var (
MaxTemplateFileSizeForEncoding = 1024 * 1024 MaxTemplateFileSizeForEncoding = 1024 * 1024
) )
@ -114,10 +111,6 @@ type ExecutorOptions struct {
// JsCompiler is abstracted javascript compiler which adds node modules and provides execution // JsCompiler is abstracted javascript compiler which adds node modules and provides execution
// environment for javascript templates // environment for javascript templates
JsCompiler *compiler.Compiler JsCompiler *compiler.Compiler
// Optional Callback function to update Thread count in payloads across all protocols
// based on given logic. by default nuclei reverts to using value of `-c` when threads count
// is not specified or is 0 in template
OverrideThreadsCount PayloadThreadSetterCallback
// AuthProvider is a provider for auth strategies // AuthProvider is a provider for auth strategies
AuthProvider authprovider.AuthProvider AuthProvider authprovider.AuthProvider
//TemporaryDirectory is the directory to store temporary files //TemporaryDirectory is the directory to store temporary files
@ -142,14 +135,11 @@ func (eo *ExecutorOptions) RateLimitTake() {
// GetThreadsForPayloadRequests returns the number of threads to use as default for // GetThreadsForPayloadRequests returns the number of threads to use as default for
// given max-request of payloads // given max-request of payloads
func (e *ExecutorOptions) GetThreadsForNPayloadRequests(totalRequests int, currentThreads int) int { func (e *ExecutorOptions) GetThreadsForNPayloadRequests(totalRequests int, currentThreads int) int {
if e.OverrideThreadsCount != nil {
return e.OverrideThreadsCount(e, totalRequests, currentThreads)
}
if currentThreads > 0 { if currentThreads > 0 {
return currentThreads return currentThreads
} else {
return e.Options.PayloadConcurrency
} }
return e.Options.PayloadConcurrency
} }
// CreateTemplateCtxStore creates template context store (which contains templateCtx for every scan) // CreateTemplateCtxStore creates template context store (which contains templateCtx for every scan)