Merge pull request #4986 from projectdiscovery/feat-3072-init-adaptive-speed

Initial Refactor for speed control
Branch: dev
Mzack9999 authored 2024-04-10 00:49:33 +01:00 (committed by GitHub)
commit 721ddda915
34 changed files with 464 additions and 166 deletions

View File

@@ -66,3 +66,7 @@ jobs:
   - name: Example SDK Advanced
     run: go run .
     working-directory: examples/advanced/
+  - name: Example SDK with speed control
+    run: go run .
+    working-directory: examples/with_speed_control/

View File

@@ -15,6 +15,7 @@ import (
     "github.com/projectdiscovery/utils/auth/pdcp"
     "github.com/projectdiscovery/utils/env"
     _ "github.com/projectdiscovery/utils/pprof"
+    stringsutil "github.com/projectdiscovery/utils/strings"
     "github.com/projectdiscovery/goflags"
     "github.com/projectdiscovery/gologger"
@@ -329,13 +330,15 @@ on extensive configurability, massive extensibility and ease of use.`)
     flagSet.CreateGroup("rate-limit", "Rate-Limit",
         flagSet.IntVarP(&options.RateLimit, "rate-limit", "rl", 150, "maximum number of requests to send per second"),
-        flagSet.IntVarP(&options.RateLimitMinute, "rate-limit-minute", "rlm", 0, "maximum number of requests to send per minute"),
+        flagSet.DurationVarP(&options.RateLimitDuration, "rate-limit-duration", "rld", time.Second, "maximum number of requests to send per second"),
+        flagSet.IntVarP(&options.RateLimitMinute, "rate-limit-minute", "rlm", 0, "maximum number of requests to send per minute (DEPRECATED)"),
         flagSet.IntVarP(&options.BulkSize, "bulk-size", "bs", 25, "maximum number of hosts to be analyzed in parallel per template"),
         flagSet.IntVarP(&options.TemplateThreads, "concurrency", "c", 25, "maximum number of templates to be executed in parallel"),
         flagSet.IntVarP(&options.HeadlessBulkSize, "headless-bulk-size", "hbs", 10, "maximum number of headless hosts to be analyzed in parallel per template"),
         flagSet.IntVarP(&options.HeadlessTemplateThreads, "headless-concurrency", "headc", 10, "maximum number of headless templates to be executed in parallel"),
         flagSet.IntVarP(&options.JsConcurrency, "js-concurrency", "jsc", 120, "maximum number of javascript runtimes to be executed in parallel"),
         flagSet.IntVarP(&options.PayloadConcurrency, "payload-concurrency", "pc", 25, "max payload concurrency for each template"),
+        flagSet.IntVarP(&options.ProbeConcurrency, "probe-concurrency", "prc", 50, "http probe concurrency with httpx"),
     )
     flagSet.CreateGroup("optimization", "Optimizations",
         flagSet.IntVar(&options.Timeout, "timeout", 10, "time to wait in seconds before timeout"),
@@ -597,10 +600,10 @@ Note: Make sure you have backup of your custom nuclei-templates before proceedin
         gologger.Fatal().Msgf("could not read response: %s", err)
     }
     resp = strings.TrimSpace(resp)
-    if strings.EqualFold(resp, "y") || strings.EqualFold(resp, "yes") {
+    if stringsutil.EqualFoldAny(resp, "y", "yes") {
         break
     }
-    if strings.EqualFold(resp, "n") || strings.EqualFold(resp, "no") || resp == "" {
+    if stringsutil.EqualFoldAny(resp, "n", "no", "") {
        fmt.Println("Exiting...")
        os.Exit(0)
     }
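Taken together, the new flags make the global rate window tunable from the CLI: -rl sets the number of requests and -rld the interval they are spread over, while -rlm remains only as a deprecated per-minute alias. A hypothetical invocation (target URL is illustrative) capping a scan at 30 requests per minute might look like:

nuclei -u https://example.com -rl 30 -rld 1m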

View File

@@ -2,7 +2,7 @@ package main
 import (
     nuclei "github.com/projectdiscovery/nuclei/v3/lib"
-    "github.com/remeh/sizedwaitgroup"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 func main() {
@@ -12,7 +12,10 @@ func main() {
         panic(err)
     }
     // setup sizedWaitgroup to handle concurrency
-    sg := sizedwaitgroup.New(10)
+    sg, err := syncutil.New(syncutil.WithSize(10))
+    if err != nil {
+        panic(err)
+    }
     // scan 1 = run dns templates on scanme.sh
     sg.Add()
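This swap from remeh/sizedwaitgroup to projectdiscovery's adaptive wait group is the pattern repeated throughout the PR: the adaptive group exposes its current Size and can be Resize()d while workers are in flight, which is what later enables live speed control. A minimal standalone sketch of that API, with illustrative sizes and a dummy worker body:

package main

import (
    "fmt"

    syncutil "github.com/projectdiscovery/utils/sync"
)

func main() {
    // start small, as the engine does with its initial options
    awg, err := syncutil.New(syncutil.WithSize(5))
    if err != nil {
        panic(err)
    }
    for i := 0; i < 100; i++ {
        // resize check point - nop if there are no changes
        if awg.Size != 20 {
            awg.Resize(20)
        }
        awg.Add()
        go func(n int) {
            defer awg.Done()
            fmt.Println("processed", n)
        }(i)
    }
    awg.Wait()
}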

View File

@@ -0,0 +1,104 @@
package main

import (
    "log"
    "sync"
    "time"

    nuclei "github.com/projectdiscovery/nuclei/v3/lib"
    "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
)

func main() {
    ne, err := initializeNucleiEngine()
    if err != nil {
        panic(err)
    }
    defer ne.Close()

    ne.LoadTargets([]string{"http://honey.scanme.sh"}, false)

    var wg sync.WaitGroup
    wg.Add(3)
    go testRateLimit(&wg, ne)
    go testThreadsAndBulkSize(&wg, ne)
    go testPayloadConcurrency(&wg, ne)

    err = ne.ExecuteWithCallback(nil)
    if err != nil {
        panic(err)
    }
    wg.Wait()
}

func initializeNucleiEngine() (*nuclei.NucleiEngine, error) {
    return nuclei.NewNucleiEngine(
        nuclei.WithTemplateFilters(nuclei.TemplateFilters{Tags: []string{"oast"}}),
        nuclei.EnableStatsWithOpts(nuclei.StatsOptions{MetricServerPort: 6064}),
        nuclei.WithGlobalRateLimit(1, time.Second),
        nuclei.WithConcurrency(nuclei.Concurrency{
            TemplateConcurrency:           1,
            HostConcurrency:               1,
            HeadlessHostConcurrency:       1,
            HeadlessTemplateConcurrency:   1,
            JavascriptTemplateConcurrency: 1,
            TemplatePayloadConcurrency:    1,
            ProbeConcurrency:              1,
        }),
    )
}

func testRateLimit(wg *sync.WaitGroup, ne *nuclei.NucleiEngine) {
    defer wg.Done()
    verifyRateLimit(ne, 1, 5000)
}

func testThreadsAndBulkSize(wg *sync.WaitGroup, ne *nuclei.NucleiEngine) {
    defer wg.Done()
    initialTemplateThreads, initialBulkSize := 1, 1
    verifyThreadsAndBulkSize(ne, initialTemplateThreads, initialBulkSize, 25, 25)
}

func testPayloadConcurrency(wg *sync.WaitGroup, ne *nuclei.NucleiEngine) {
    defer wg.Done()
    verifyPayloadConcurrency(ne, 1, 500)
}

func verifyRateLimit(ne *nuclei.NucleiEngine, initialRate, finalRate int) {
    if ne.GetExecuterOptions().RateLimiter.GetLimit() != uint(initialRate) {
        panic("wrong initial rate limit")
    }
    time.Sleep(5 * time.Second)
    ne.Options().RateLimit = finalRate
    time.Sleep(20 * time.Second)
    if ne.GetExecuterOptions().RateLimiter.GetLimit() != uint(finalRate) {
        panic("wrong final rate limit")
    }
}

func verifyThreadsAndBulkSize(ne *nuclei.NucleiEngine, initialThreads, initialBulk, finalThreads, finalBulk int) {
    if ne.Options().TemplateThreads != initialThreads || ne.Options().BulkSize != initialBulk {
        panic("wrong initial standard concurrency")
    }
    time.Sleep(5 * time.Second)
    ne.Options().TemplateThreads = finalThreads
    ne.Options().BulkSize = finalBulk
    time.Sleep(20 * time.Second)
    if ne.Engine().GetWorkPool().InputPool(types.HTTPProtocol).Size != finalBulk || ne.Engine().WorkPool().Default.Size != finalThreads {
        log.Fatal("wrong final concurrency", ne.Engine().WorkPool().Default.Size, finalThreads, ne.Engine().GetWorkPool().InputPool(types.HTTPProtocol).Size, finalBulk)
    }
}

func verifyPayloadConcurrency(ne *nuclei.NucleiEngine, initialPayloadConcurrency, finalPayloadConcurrency int) {
    if ne.Options().PayloadConcurrency != initialPayloadConcurrency {
        panic("wrong initial payload concurrency")
    }
    time.Sleep(5 * time.Second)
    ne.Options().PayloadConcurrency = finalPayloadConcurrency
    time.Sleep(20 * time.Second)
    if ne.GetExecuterOptions().GetThreadsForNPayloadRequests(100, 0) != finalPayloadConcurrency {
        panic("wrong final payload concurrency")
    }
}

go.mod (12 changed lines)
View File

@@ -38,7 +38,7 @@ require (
 	github.com/weppos/publicsuffix-go v0.30.2-0.20230730094716-a20f9abcc222
 	github.com/xanzy/go-gitlab v0.84.0
 	go.uber.org/multierr v1.11.0
-	golang.org/x/net v0.21.0
+	golang.org/x/net v0.24.0
 	golang.org/x/oauth2 v0.11.0
 	golang.org/x/text v0.14.0
 	gopkg.in/yaml.v2 v2.4.0
@@ -94,13 +94,13 @@ require (
 	github.com/projectdiscovery/tlsx v1.1.6
 	github.com/projectdiscovery/uncover v1.0.7
 	github.com/projectdiscovery/useragent v0.0.40
-	github.com/projectdiscovery/utils v0.0.88
+	github.com/projectdiscovery/utils v0.0.88-0.20240404181359-663cfe2196d0
 	github.com/projectdiscovery/wappalyzergo v0.0.116
 	github.com/redis/go-redis/v9 v9.1.0
 	github.com/seh-msft/burpxml v1.0.1
 	github.com/stretchr/testify v1.9.0
 	github.com/zmap/zgrab2 v0.1.8-0.20230806160807-97ba87c0e706
-	golang.org/x/term v0.17.0
+	golang.org/x/term v0.19.0
 	gopkg.in/yaml.v3 v3.0.1
 	moul.io/http2curl v1.0.0
 )
@@ -142,6 +142,8 @@ require (
 	github.com/docker/cli v24.0.5+incompatible // indirect
 	github.com/docker/docker v24.0.9+incompatible // indirect
 	github.com/docker/go-connections v0.4.0 // indirect
+	github.com/eapache/channels v1.1.0 // indirect
+	github.com/eapache/queue v1.1.0 // indirect
 	github.com/fatih/color v1.15.0 // indirect
 	github.com/free5gc/util v1.0.5-0.20230511064842-2e120956883b // indirect
 	github.com/gabriel-vasile/mimetype v1.4.2 // indirect
@@ -300,10 +302,10 @@ require (
 	go.etcd.io/bbolt v1.3.8 // indirect
 	go.uber.org/zap v1.25.0 // indirect
 	goftp.io/server/v2 v2.0.1 // indirect
-	golang.org/x/crypto v0.19.0 // indirect
+	golang.org/x/crypto v0.22.0 // indirect
 	golang.org/x/exp v0.0.0-20240119083558-1b970713d09a
 	golang.org/x/mod v0.14.0 // indirect
-	golang.org/x/sys v0.17.0 // indirect
+	golang.org/x/sys v0.19.0 // indirect
 	golang.org/x/time v0.5.0 // indirect
 	golang.org/x/tools v0.17.0
 	google.golang.org/appengine v1.6.7 // indirect

go.sum (22 changed lines)
View File

@@ -296,8 +296,11 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj6
 github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
 github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/eapache/channels v1.1.0 h1:F1taHcn7/F0i8DYqKXJnyhJcVpp2kgFcNePxXtnyu4k=
+github.com/eapache/channels v1.1.0/go.mod h1:jMm2qB5Ubtg9zLd+inMZd2/NUvXgzmWXsDaLyQIGfH0=
 github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
 github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
+github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
 github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
 github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M=
 github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU=
@@ -883,8 +886,8 @@ github.com/projectdiscovery/uncover v1.0.7 h1:ut+2lTuvmftmveqF5RTjMWAgyLj8ltPQC7
 github.com/projectdiscovery/uncover v1.0.7/go.mod h1:HFXgm1sRPuoN0D4oATljPIdmbo/EEh1wVuxQqo/dwFE=
 github.com/projectdiscovery/useragent v0.0.40 h1:1LUhReSGPkhqsM5n40OOC9dIoNqMGs1dyGFJcOmg2Fo=
 github.com/projectdiscovery/useragent v0.0.40/go.mod h1:EvK1x3s948Gtqb/XOahXcauyejCL/rSgy5d1IAvsKT4=
-github.com/projectdiscovery/utils v0.0.88 h1:oYfCXM+8VHNLyH/H6cOibkuDUwHUAOBAMRNPFX6NPrs=
-github.com/projectdiscovery/utils v0.0.88/go.mod h1:lAWzFdGXtJRPKdhUu1Z46d8B8JbASTk1Z69WY6H/3kA=
+github.com/projectdiscovery/utils v0.0.88-0.20240404181359-663cfe2196d0 h1:2ZR0yiN0cUm/qYEMq79MfcbgM374lJSdftheYhMFxNo=
+github.com/projectdiscovery/utils v0.0.88-0.20240404181359-663cfe2196d0/go.mod h1:lAWzFdGXtJRPKdhUu1Z46d8B8JbASTk1Z69WY6H/3kA=
 github.com/projectdiscovery/wappalyzergo v0.0.116 h1:xy+mBpwbYo/0PSzmJOQ/RXHomEh0D3nDBcbCxsW69m8=
 github.com/projectdiscovery/wappalyzergo v0.0.116/go.mod h1:hc/o+fgM8KtdpFesjfBTmHTwsR+yBd+4kYZW/DGy/x8=
 github.com/projectdiscovery/yamldoc-go v1.0.4 h1:eZoESapnMw6WAHiVgRwNqvbJEfNHEH148uthhFbG5jE=
@@ -1183,8 +1186,8 @@ golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58
 golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
 golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
 golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
-golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
-golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
+golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
+golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -1277,8 +1280,8 @@ golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
 golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
 golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
-golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
-golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
+golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -1379,8 +1382,9 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
 golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
+golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
@@ -1392,8 +1396,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
 golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
 golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
 golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
-golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
-golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
+golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
+golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
 golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -12,15 +12,14 @@ import (
     "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
     "github.com/projectdiscovery/nuclei/v3/pkg/utils"
     stringsutil "github.com/projectdiscovery/utils/strings"
-    "github.com/remeh/sizedwaitgroup"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
-const probeBulkSize = 50
+var GlobalProbeBulkSize = 50
 // initializeTemplatesHTTPInput initializes the http form of input
 // for any loaded http templates if input is in non-standard format.
 func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) {
     hm, err := hybrid.New(hybrid.DefaultDiskOptions)
     if err != nil {
         return nil, errors.Wrap(err, "could not create temporary input file")
@@ -31,8 +30,8 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) {
     }
     gologger.Info().Msgf("Running httpx on input host")
-    var bulkSize = probeBulkSize
-    if r.options.BulkSize > probeBulkSize {
+    var bulkSize = GlobalProbeBulkSize
+    if r.options.BulkSize > GlobalProbeBulkSize {
         bulkSize = r.options.BulkSize
     }
@@ -44,20 +43,29 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) {
         return nil, errors.Wrap(err, "could not create httpx client")
     }
+    shouldFollowGlobalProbeBulkSize := bulkSize == GlobalProbeBulkSize
     // Probe the non-standard URLs and store them in cache
-    swg := sizedwaitgroup.New(bulkSize)
-    count := int32(0)
+    swg, err := syncutil.New(syncutil.WithSize(bulkSize))
+    if err != nil {
+        return nil, errors.Wrap(err, "could not create adaptive group")
+    }
+    var count atomic.Int32
     r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool {
         if stringsutil.HasPrefixAny(value.Input, "http://", "https://") {
             return true
         }
+        if shouldFollowGlobalProbeBulkSize && swg.Size != GlobalProbeBulkSize {
+            swg.Resize(GlobalProbeBulkSize)
+        }
         swg.Add()
         go func(input *contextargs.MetaInput) {
             defer swg.Done()
             if result := utils.ProbeURL(input.Input, httpxClient); result != "" {
-                atomic.AddInt32(&count, 1)
+                count.Add(1)
                 _ = hm.Set(input.Input, []byte(result))
             }
         }(value)
@@ -65,6 +73,6 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) {
     })
     swg.Wait()
-    gologger.Info().Msgf("Found %d URL from httpx", atomic.LoadInt32(&count))
+    gologger.Info().Msgf("Found %d URL from httpx", count.Load())
     return hm, nil
 }

View File

@@ -314,11 +314,17 @@ func New(options *types.Options) (*Runner, error) {
     }
     if options.RateLimitMinute > 0 {
-        runner.rateLimiter = ratelimit.New(context.Background(), uint(options.RateLimitMinute), time.Minute)
-    } else if options.RateLimit > 0 {
-        runner.rateLimiter = ratelimit.New(context.Background(), uint(options.RateLimit), time.Second)
-    } else {
+        gologger.Print().Msgf("[%v] %v", aurora.BrightYellow("WRN"), "rate limit per minute is deprecated - use rate-limit-duration")
+        options.RateLimit = options.RateLimitMinute
+        options.RateLimitDuration = time.Minute
+    }
+    if options.RateLimit > 0 && options.RateLimitDuration == 0 {
+        options.RateLimitDuration = time.Second
+    }
+    if options.RateLimit == 0 && options.RateLimitDuration == 0 {
         runner.rateLimiter = ratelimit.NewUnlimited(context.Background())
+    } else {
+        runner.rateLimiter = ratelimit.New(context.Background(), uint(options.RateLimit), options.RateLimitDuration)
     }
     if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil {

View File

@@ -2,6 +2,7 @@ package nuclei
 import (
     "context"
+    "errors"
     "time"
     "github.com/projectdiscovery/goflags"
@@ -115,17 +116,48 @@ type Concurrency struct {
     HeadlessTemplateConcurrency   int // number of templates to run concurrently for headless templates (per host in host-spray mode)
     JavascriptTemplateConcurrency int // number of templates to run concurrently for javascript templates (per host in host-spray mode)
     TemplatePayloadConcurrency    int // max concurrent payloads to run for a template (a good default is 25)
+    ProbeConcurrency              int // max concurrent http probes to run (a good default is 50)
 }
 // WithConcurrency sets concurrency options
 func WithConcurrency(opts Concurrency) NucleiSDKOptions {
     return func(e *NucleiEngine) error {
-        e.opts.TemplateThreads = opts.TemplateConcurrency
-        e.opts.BulkSize = opts.HostConcurrency
-        e.opts.HeadlessBulkSize = opts.HeadlessHostConcurrency
-        e.opts.HeadlessTemplateThreads = opts.HeadlessTemplateConcurrency
-        e.opts.JsConcurrency = opts.JavascriptTemplateConcurrency
-        e.opts.PayloadConcurrency = opts.TemplatePayloadConcurrency
+        // minimum required is 1
+        if opts.TemplateConcurrency <= 0 {
+            return errors.New("template threads must be at least 1")
+        } else {
+            e.opts.TemplateThreads = opts.TemplateConcurrency
+        }
+        if opts.HostConcurrency <= 0 {
+            return errors.New("host concurrency must be at least 1")
+        } else {
+            e.opts.BulkSize = opts.HostConcurrency
+        }
+        if opts.HeadlessHostConcurrency <= 0 {
+            return errors.New("headless host concurrency must be at least 1")
+        } else {
+            e.opts.HeadlessBulkSize = opts.HeadlessHostConcurrency
+        }
+        if opts.HeadlessTemplateConcurrency <= 0 {
+            return errors.New("headless template threads must be at least 1")
+        } else {
+            e.opts.HeadlessTemplateThreads = opts.HeadlessTemplateConcurrency
+        }
+        if opts.JavascriptTemplateConcurrency <= 0 {
+            return errors.New("js must be at least 1")
+        } else {
+            e.opts.JsConcurrency = opts.JavascriptTemplateConcurrency
+        }
+        if opts.TemplatePayloadConcurrency <= 0 {
+            return errors.New("payload concurrency must be at least 1")
+        } else {
+            e.opts.PayloadConcurrency = opts.TemplatePayloadConcurrency
+        }
+        if opts.ProbeConcurrency <= 0 {
+            return errors.New("probe concurrency must be at least 1")
+        } else {
+            e.opts.ProbeConcurrency = opts.ProbeConcurrency
+        }
         return nil
     }
 }
@@ -133,7 +165,9 @@ func WithConcurrency(opts Concurrency) NucleiSDKOptions {
 // WithGlobalRateLimit sets global rate (i.e all hosts combined) limit options
 func WithGlobalRateLimit(maxTokens int, duration time.Duration) NucleiSDKOptions {
     return func(e *NucleiEngine) error {
-        e.rateLimiter = ratelimit.New(context.Background(), uint(maxTokens), duration)
+        e.opts.RateLimit = maxTokens
+        e.opts.RateLimitDuration = duration
+        e.rateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration)
         return nil
     }
 }
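For SDK consumers, a sketch of wiring the speed options together at engine construction might look like the following; the values are illustrative defaults rather than recommendations, and the fragment assumes the usual nuclei v3 lib and time imports:

ne, err := nuclei.NewNucleiEngine(
    nuclei.WithGlobalRateLimit(100, time.Minute), // 100 requests per minute, engine-wide
    nuclei.WithConcurrency(nuclei.Concurrency{
        TemplateConcurrency:           25,
        HostConcurrency:               25,
        HeadlessHostConcurrency:       10,
        HeadlessTemplateConcurrency:   10,
        JavascriptTemplateConcurrency: 120,
        TemplatePayloadConcurrency:    25,
        ProbeConcurrency:              50, // new field; every value must be >= 1 or WithConcurrency returns an error
    }),
)
if err != nil {
    panic(err)
}
defer ne.Close()

Because WithGlobalRateLimit now also records the rate on the engine options, the same values can later be adjusted through ne.Options(), as the with_speed_control example above does.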

View File

@@ -42,11 +42,16 @@ func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOpt
         Parser: base.parser,
     }
     if opts.RateLimitMinute > 0 {
-        u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimitMinute), time.Minute)
-    } else if opts.RateLimit > 0 {
-        u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimit), time.Second)
-    } else {
+        opts.RateLimit = opts.RateLimitMinute
+        opts.RateLimitDuration = time.Minute
+    }
+    if opts.RateLimit > 0 && opts.RateLimitDuration == 0 {
+        opts.RateLimitDuration = time.Second
+    }
+    if opts.RateLimit == 0 && opts.RateLimitDuration == 0 {
         u.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background())
+    } else {
+        u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimit), opts.RateLimitDuration)
     }
     u.engine = core.New(opts)
     u.engine.SetExecuterOptions(u.executerOpts)

View File

@@ -215,6 +215,14 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result
     return nil
 }
+func (e *NucleiEngine) Options() *types.Options {
+    return e.opts
+}
+
+func (e *NucleiEngine) Engine() *core.Engine {
+    return e.engine
+}
+
 // NewNucleiEngine creates a new nuclei engine instance
 func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
     // default options

View File

@@ -192,11 +192,16 @@ func (e *NucleiEngine) init() error {
     if e.executerOpts.RateLimiter == nil {
         if e.opts.RateLimitMinute > 0 {
-            e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimitMinute), time.Minute)
-        } else if e.opts.RateLimit > 0 {
-            e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), time.Second)
-        } else {
+            e.opts.RateLimit = e.opts.RateLimitMinute
+            e.opts.RateLimitDuration = time.Minute
+        }
+        if e.opts.RateLimit > 0 && e.opts.RateLimitDuration == 0 {
+            e.opts.RateLimitDuration = time.Second
+        }
+        if e.opts.RateLimit == 0 && e.opts.RateLimitDuration == 0 {
             e.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background())
+        } else {
+            e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration)
         }
     }

View File

@@ -30,14 +30,19 @@ func New(options *types.Options) *Engine {
     return engine
 }
-// GetWorkPool returns a workpool from options
-func (e *Engine) GetWorkPool() *WorkPool {
-    return NewWorkPool(WorkPoolConfig{
+func (e *Engine) GetWorkPoolConfig() WorkPoolConfig {
+    config := WorkPoolConfig{
         InputConcurrency:         e.options.BulkSize,
         TypeConcurrency:          e.options.TemplateThreads,
         HeadlessInputConcurrency: e.options.HeadlessBulkSize,
         HeadlessTypeConcurrency:  e.options.HeadlessTemplateThreads,
-    })
+    }
+    return config
+}
+
+// GetWorkPool returns a workpool from options
+func (e *Engine) GetWorkPool() *WorkPool {
+    return NewWorkPool(e.GetWorkPoolConfig())
 }
 // SetExecuterOptions sets the executer options for the engine. This is required
@@ -53,5 +58,7 @@ func (e *Engine) ExecuterOptions() protocols.ExecutorOptions {
 // WorkPool returns the worker pool for the engine
 func (e *Engine) WorkPool() *WorkPool {
+    // resize check point - nop if there are no changes
+    e.workPool.RefreshWithConfig(e.GetWorkPoolConfig())
     return e.workPool
 }

View File

@@ -4,8 +4,6 @@ import (
     "sync"
     "sync/atomic"
-    "github.com/remeh/sizedwaitgroup"
     "github.com/projectdiscovery/gologger"
     "github.com/projectdiscovery/nuclei/v3/pkg/input/provider"
     "github.com/projectdiscovery/nuclei/v3/pkg/output"
@@ -14,6 +12,7 @@ import (
     "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
     "github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"
     stringsutil "github.com/projectdiscovery/utils/strings"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 // Execute takes a list of templates/workflows that have been compiled
@@ -109,9 +108,11 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
     wp := e.GetWorkPool()
     for _, template := range templatesList {
-        templateType := template.Type()
-        var wg *sizedwaitgroup.SizedWaitGroup
+        // resize check point - nop if there are no changes
+        wp.RefreshWithConfig(e.GetWorkPoolConfig())
+        templateType := template.Type()
+        var wg *syncutil.AdaptiveWaitGroup
         if templateType == types.HeadlessProtocol {
             wg = wp.Headless
         } else {
@@ -134,7 +135,7 @@
 // executeHostSpray executes scan using host spray strategy where templates are iterated over each target
 func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
     results := &atomic.Bool{}
-    wp := sizedwaitgroup.New(e.options.BulkSize + e.options.HeadlessBulkSize)
+    wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize))
     target.Iterate(func(value *contextargs.MetaInput) bool {
         wp.Add()

View File

@@ -11,7 +11,7 @@ import (
     "github.com/projectdiscovery/nuclei/v3/pkg/templates"
     "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
     generalTypes "github.com/projectdiscovery/nuclei/v3/pkg/types"
-    "github.com/remeh/sizedwaitgroup"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 // Executors are low level executors that deals with template execution on a target
@@ -104,9 +104,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
         return true
     }
-    wg.WaitGroup.Add()
+    wg.Add()
     go func(index uint32, skip bool, value *contextargs.MetaInput) {
-        defer wg.WaitGroup.Done()
+        defer wg.Done()
         defer cleanupInFlight(index)
         if skip {
             return
@@ -140,7 +140,7 @@
     index++
     return true
 })
-wg.WaitGroup.Wait()
+wg.Wait()
 // on completion marks the template as completed
 currentInfo.Lock()
@@ -158,14 +158,17 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
     wp := e.GetWorkPool()
     for _, tpl := range alltemplates {
-        var sg *sizedwaitgroup.SizedWaitGroup
+        // resize check point - nop if there are no changes
+        wp.RefreshWithConfig(e.GetWorkPoolConfig())
+        var sg *syncutil.AdaptiveWaitGroup
         if tpl.Type() == types.HeadlessProtocol {
             sg = wp.Headless
         } else {
             sg = wp.Default
         }
         sg.Add()
-        go func(template *templates.Template, value *contextargs.MetaInput, wg *sizedwaitgroup.SizedWaitGroup) {
+        go func(template *templates.Template, value *contextargs.MetaInput, wg *syncutil.AdaptiveWaitGroup) {
             defer wg.Done()
             var match bool
@@ -213,7 +216,10 @@ func (e *ChildExecuter) Close() *atomic.Bool {
 func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs.MetaInput) {
     templateType := template.Type()
-    var wg *sizedwaitgroup.SizedWaitGroup
+    // resize check point - nop if there are no changes
+    e.e.WorkPool().RefreshWithConfig(e.e.GetWorkPoolConfig())
+    var wg *syncutil.AdaptiveWaitGroup
     if templateType == types.HeadlessProtocol {
         wg = e.e.workPool.Headless
     } else {

View File

@@ -5,13 +5,12 @@ import (
     "net/http/cookiejar"
     "sync/atomic"
-    "github.com/remeh/sizedwaitgroup"
     "github.com/projectdiscovery/gologger"
     "github.com/projectdiscovery/nuclei/v3/pkg/output"
     "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
     "github.com/projectdiscovery/nuclei/v3/pkg/scan"
     "github.com/projectdiscovery/nuclei/v3/pkg/workflows"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 const workflowStepExecutionError = "[%s] Could not execute workflow step: %s\n"
@@ -32,7 +31,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b
     if templateThreads == 1 {
         templateThreads++
     }
-    swg := sizedwaitgroup.New(templateThreads)
+    swg, _ := syncutil.New(syncutil.WithSize(templateThreads))
     for _, template := range w.Workflows {
         swg.Add()
@@ -40,7 +39,7 @@
         func(template *workflows.WorkflowTemplate) {
             defer swg.Done()
-            if err := e.runWorkflowStep(template, ctx, results, &swg, w); err != nil {
+            if err := e.runWorkflowStep(template, ctx, results, swg, w); err != nil {
                 gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err)
             }
         }(template)
@@ -51,7 +50,7 @@
 // runWorkflowStep runs a workflow step for the workflow. It executes the workflow
 // in a recursive manner running all subtemplates and matchers.
-func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan.ScanContext, results *atomic.Bool, swg *sizedwaitgroup.SizedWaitGroup, w *workflows.Workflow) error {
+func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan.ScanContext, results *atomic.Bool, swg *syncutil.AdaptiveWaitGroup, w *workflows.Workflow) error {
     var firstMatched bool
     var err error
     var mainErr error

View File

@@ -1,9 +1,8 @@
 package core
 import (
-    "github.com/remeh/sizedwaitgroup"
     "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 // WorkPool implements an execution pool for executing different
@@ -12,8 +11,8 @@ import (
 // It also allows Configuration of such requirements. This is used
 // for per-module like separate headless concurrency etc.
 type WorkPool struct {
-    Headless *sizedwaitgroup.SizedWaitGroup
-    Default  *sizedwaitgroup.SizedWaitGroup
+    Headless *syncutil.AdaptiveWaitGroup
+    Default  *syncutil.AdaptiveWaitGroup
     config   WorkPoolConfig
 }
@@ -31,13 +30,13 @@ type WorkPoolConfig struct {
 // NewWorkPool returns a new WorkPool instance
 func NewWorkPool(config WorkPoolConfig) *WorkPool {
-    headlessWg := sizedwaitgroup.New(config.HeadlessTypeConcurrency)
-    defaultWg := sizedwaitgroup.New(config.TypeConcurrency)
+    headlessWg, _ := syncutil.New(syncutil.WithSize(config.HeadlessTypeConcurrency))
+    defaultWg, _ := syncutil.New(syncutil.WithSize(config.TypeConcurrency))
     return &WorkPool{
         config:   config,
-        Headless: &headlessWg,
-        Default:  &defaultWg,
+        Headless: headlessWg,
+        Default:  defaultWg,
     }
 }
@@ -47,19 +46,39 @@ func (w *WorkPool) Wait() {
     w.Headless.Wait()
 }
-// InputWorkPool is a work pool per-input
-type InputWorkPool struct {
-    WaitGroup *sizedwaitgroup.SizedWaitGroup
-}
 // InputPool returns a work pool for an input type
-func (w *WorkPool) InputPool(templateType types.ProtocolType) *InputWorkPool {
+func (w *WorkPool) InputPool(templateType types.ProtocolType) *syncutil.AdaptiveWaitGroup {
     var count int
     if templateType == types.HeadlessProtocol {
         count = w.config.HeadlessInputConcurrency
     } else {
         count = w.config.InputConcurrency
     }
-    swg := sizedwaitgroup.New(count)
-    return &InputWorkPool{WaitGroup: &swg}
+    swg, _ := syncutil.New(syncutil.WithSize(count))
+    return swg
+}
+
+func (w *WorkPool) RefreshWithConfig(config WorkPoolConfig) {
+    if w.config.TypeConcurrency != config.TypeConcurrency {
+        w.config.TypeConcurrency = config.TypeConcurrency
+    }
+    if w.config.HeadlessTypeConcurrency != config.HeadlessTypeConcurrency {
+        w.config.HeadlessTypeConcurrency = config.HeadlessTypeConcurrency
+    }
+    if w.config.InputConcurrency != config.InputConcurrency {
+        w.config.InputConcurrency = config.InputConcurrency
+    }
+    if w.config.HeadlessInputConcurrency != config.HeadlessInputConcurrency {
+        w.config.HeadlessInputConcurrency = config.HeadlessInputConcurrency
+    }
+    w.Refresh()
+}
+
+func (w *WorkPool) Refresh() {
+    if w.Default.Size != w.config.TypeConcurrency {
+        w.Default.Resize(w.config.TypeConcurrency)
+    }
+    if w.Headless.Size != w.config.HeadlessTypeConcurrency {
+        w.Headless.Resize(w.config.HeadlessTypeConcurrency)
+    }
 }
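These two methods are what make the work pool adaptive: RefreshWithConfig copies any changed values into the pool's config, and Refresh resizes the Default and Headless groups only when their current Size differs from the target. A minimal sketch of that flow, assuming it runs inside (or imports) pkg/core and using illustrative numbers:

// build a pool from an initial config
wp := NewWorkPool(WorkPoolConfig{
    TypeConcurrency:          10,
    HeadlessTypeConcurrency:  5,
    InputConcurrency:         25,
    HeadlessInputConcurrency: 10,
})

// ... later, options changed mid-scan: raise template concurrency to 50 ...
wp.RefreshWithConfig(WorkPoolConfig{
    TypeConcurrency:          50,
    HeadlessTypeConcurrency:  5,
    InputConcurrency:         25,
    HeadlessInputConcurrency: 10,
})
// wp.Default has now been resized to 50; wp.Headless is left untouched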

View File

@@ -4,13 +4,13 @@ import (
     "sync"
     "github.com/dop251/goja"
-    "github.com/remeh/sizedwaitgroup"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 var (
-    ephemeraljsc    = sizedwaitgroup.New(NonPoolingVMConcurrency)
+    ephemeraljsc, _ = syncutil.New(syncutil.WithSize(NonPoolingVMConcurrency))
     lazyFixedSgInit = sync.OnceFunc(func() {
-        ephemeraljsc = sizedwaitgroup.New(NonPoolingVMConcurrency)
+        ephemeraljsc, _ = syncutil.New(syncutil.WithSize(NonPoolingVMConcurrency))
     })
 )

View File

@@ -36,7 +36,7 @@ import (
     "github.com/projectdiscovery/nuclei/v3/pkg/js/libs/goconsole"
     "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
     stringsutil "github.com/projectdiscovery/utils/strings"
-    "github.com/remeh/sizedwaitgroup"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 const (
@@ -51,10 +51,16 @@ var (
     // autoregister console node module with default printer it uses gologger backend
         require.RegisterNativeModule(console.ModuleName, console.RequireWithPrinter(goconsole.NewGoConsolePrinter()))
     })
-    pooljsc    sizedwaitgroup.SizedWaitGroup
+    pooljsc    *syncutil.AdaptiveWaitGroup
     lazySgInit = sync.OnceFunc(func() {
-        pooljsc = sizedwaitgroup.New(PoolingJsVmConcurrency)
+        pooljsc, _ = syncutil.New(syncutil.WithSize(PoolingJsVmConcurrency))
     })
+    sgResizeCheck = func() {
+        // resize check point
+        if pooljsc.Size != PoolingJsVmConcurrency {
+            pooljsc.Resize(PoolingJsVmConcurrency)
+        }
+    }
 )
 var gojapool = &sync.Pool{
@@ -116,6 +122,8 @@ func executeWithPoolingProgram(p *goja.Program, args *ExecuteArgs, opts *Execute
     // its unknown (most likely cannot be done) to limit max js runtimes at a moment without making it static
     // unlike sync.Pool which reacts to GC and its purposes is to reuse objects rather than creating new ones
     lazySgInit()
+    sgResizeCheck()
     pooljsc.Add()
     defer pooljsc.Done()
     runtime := gojapool.Get().(*goja.Runtime)

View File

@@ -30,8 +30,8 @@ import (
     mapsutil "github.com/projectdiscovery/utils/maps"
     sliceutil "github.com/projectdiscovery/utils/slice"
     stringsutil "github.com/projectdiscovery/utils/strings"
+    syncutil "github.com/projectdiscovery/utils/sync"
     wappalyzer "github.com/projectdiscovery/wappalyzergo"
-    "github.com/remeh/sizedwaitgroup"
     "gopkg.in/yaml.v2"
 )
@@ -128,7 +128,10 @@ func (s *Service) Close() bool {
 func (s *Service) Execute() error {
     gologger.Info().Msgf("Executing Automatic scan on %d target[s]", s.target.Count())
     // setup host concurrency
-    sg := sizedwaitgroup.New(s.opts.Options.BulkSize)
+    sg, err := syncutil.New(syncutil.WithSize(s.opts.Options.BulkSize))
+    if err != nil {
+        return err
+    }
     s.target.Iterate(func(value *contextargs.MetaInput) bool {
         sg.Add()
         go func(input *contextargs.MetaInput) {
@@ -246,7 +249,7 @@ func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) (
     // execute tech detection templates on target
     tags := map[string]struct{}{}
     m := &sync.Mutex{}
-    sg := sizedwaitgroup.New(s.opts.Options.TemplateThreads)
+    sg, _ := syncutil.New(syncutil.WithSize(s.opts.Options.TemplateThreads))
     counter := atomic.Uint32{}
     for _, t := range s.techTemplates {

View File

@@ -9,7 +9,6 @@ import (
     "github.com/miekg/dns"
     "github.com/pkg/errors"
-    "github.com/remeh/sizedwaitgroup"
     "go.uber.org/multierr"
     "golang.org/x/exp/maps"
@@ -27,6 +26,7 @@ import (
     "github.com/projectdiscovery/nuclei/v3/pkg/utils"
     "github.com/projectdiscovery/retryabledns"
     iputil "github.com/projectdiscovery/utils/ip"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 var _ protocols.Request = &Request{}
@@ -62,9 +62,15 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
     variablesMap := request.options.Variables.Evaluate(vars)
     vars = generators.MergeMaps(vars, variablesMap, request.options.Constants)
+    // if request threads matches global payload concurrency we follow it
+    shouldFollowGlobal := request.Threads == request.options.Options.PayloadConcurrency
     if request.generator != nil {
         iterator := request.generator.NewIterator()
-        swg := sizedwaitgroup.New(request.Threads)
+        swg, err := syncutil.New(syncutil.WithSize(request.Threads))
+        if err != nil {
+            return err
+        }
         var multiErr error
         m := &sync.Mutex{}
@@ -73,6 +79,12 @@
         if !ok {
             break
         }
+        // resize check point - nop if there are no changes
+        if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
+            swg.Resize(request.options.Options.PayloadConcurrency)
+        }
         value = generators.MergeMaps(vars, value)
         swg.Add()
         go func(newVars map[string]interface{}) {
@@ -140,7 +152,7 @@ func (request *Request) execute(input *contextargs.Context, domain string, metad
         }
     }
-    request.options.RateLimiter.Take()
+    request.options.RateLimitTake()
     // Send the request to the target servers
     response, err := dnsClient.Do(compiledRequest)

View File

@@ -11,7 +11,6 @@ import (
     "github.com/docker/go-units"
     "github.com/mholt/archiver"
     "github.com/pkg/errors"
-    "github.com/remeh/sizedwaitgroup"
     "github.com/projectdiscovery/gologger"
     "github.com/projectdiscovery/nuclei/v3/pkg/operators"
@@ -24,6 +23,7 @@ import (
     "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/helpers/responsehighlighter"
     templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
     sliceutil "github.com/projectdiscovery/utils/slice"
+    syncutil "github.com/projectdiscovery/utils/sync"
 )
 var _ protocols.Request = &Request{}
@@ -47,8 +47,11 @@ var errEmptyResult = errors.New("Empty result")
 // ExecuteWithResults executes the protocol requests and returns results instead of writing them.
 func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
-    wg := sizedwaitgroup.New(request.options.Options.BulkSize)
-    err := request.getInputPaths(input.MetaInput.Input, func(filePath string) {
+    wg, err := syncutil.New(syncutil.WithSize(request.options.Options.BulkSize))
+    if err != nil {
+        return err
+    }
+    err = request.getInputPaths(input.MetaInput.Input, func(filePath string) {
         wg.Add()
         func(filePath string) {
             defer wg.Done()

View File

@@ -214,7 +214,7 @@ func geTimeParameter(p *Page, act *Action, parameterName string, defaultValue ti
 }
 // ActionAddHeader executes a AddHeader action.
-func (p *Page) ActionAddHeader(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) ActionAddHeader(act *Action, out map[string]string) error {
     in := p.getActionArgWithDefaultValues(act, "part")
     args := make(map[string]string)
@@ -225,7 +225,7 @@ func (p *Page) ActionAddHeader(act *Action, out map[string]string /*TODO review
 }
 // ActionSetHeader executes a SetHeader action.
-func (p *Page) ActionSetHeader(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) ActionSetHeader(act *Action, out map[string]string) error {
     in := p.getActionArgWithDefaultValues(act, "part")
     args := make(map[string]string)
@@ -236,7 +236,7 @@ func (p *Page) ActionSetHeader(act *Action, out map[string]string /*TODO review
 }
 // ActionDeleteHeader executes a DeleteHeader action.
-func (p *Page) ActionDeleteHeader(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) ActionDeleteHeader(act *Action, out map[string]string) error {
     in := p.getActionArgWithDefaultValues(act, "part")
     args := make(map[string]string)
@@ -343,7 +343,7 @@ func (p *Page) RunScript(action *Action, out map[string]string) error {
 }
 // ClickElement executes click actions for an element.
-func (p *Page) ClickElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) ClickElement(act *Action, out map[string]string) error {
     element, err := p.pageElementBy(act.Data)
     if err != nil {
         return errors.Wrap(err, errCouldNotGetElement)
@@ -358,12 +358,12 @@ func (p *Page) ClickElement(act *Action, out map[string]string /*TODO review unu
 }
 // KeyboardAction executes a keyboard action on the page.
-func (p *Page) KeyboardAction(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) KeyboardAction(act *Action, out map[string]string) error {
     return p.page.Keyboard.Type([]input.Key(p.getActionArgWithDefaultValues(act, "keys"))...)
 }
 // RightClickElement executes right click actions for an element.
-func (p *Page) RightClickElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) RightClickElement(act *Action, out map[string]string) error {
     element, err := p.pageElementBy(act.Data)
     if err != nil {
         return errors.Wrap(err, errCouldNotGetElement)
@@ -441,7 +441,7 @@ func (p *Page) Screenshot(act *Action, out map[string]string) error {
 }
 // InputElement executes input element actions for an element.
-func (p *Page) InputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) InputElement(act *Action, out map[string]string) error {
     value := p.getActionArgWithDefaultValues(act, "value")
     if value == "" {
         return errinvalidArguments
@@ -460,7 +460,7 @@ func (p *Page) InputElement(act *Action, out map[string]string /*TODO review unu
 }
 // TimeInputElement executes time input on an element
-func (p *Page) TimeInputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) TimeInputElement(act *Action, out map[string]string) error {
     value := p.getActionArgWithDefaultValues(act, "value")
     if value == "" {
         return errinvalidArguments
@@ -483,7 +483,7 @@ func (p *Page) TimeInputElement(act *Action, out map[string]string /*TODO review
 }
 // SelectInputElement executes select input statement action on a element
-func (p *Page) SelectInputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) SelectInputElement(act *Action, out map[string]string) error {
     value := p.getActionArgWithDefaultValues(act, "value")
     if value == "" {
         return errinvalidArguments
@@ -508,7 +508,7 @@ func (p *Page) SelectInputElement(act *Action, out map[string]string /*TODO revi
 }
 // WaitLoad waits for the page to load
-func (p *Page) WaitLoad(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) WaitLoad(act *Action, out map[string]string) error {
     p.page.Timeout(2 * time.Second).WaitNavigation(proto.PageLifecycleEventNameFirstMeaningfulPaint)()
     // Wait for the window.onload event and also wait for the network requests
@@ -538,7 +538,7 @@ func (p *Page) GetResource(act *Action, out map[string]string) error {
 }
 // FilesInput acts with a file input element on page
-func (p *Page) FilesInput(act *Action, out map[string]string /*TODO review unused parameter*/) error {
+func (p *Page) FilesInput(act *Action, out map[string]string) error {
     element, err := p.pageElementBy(act.Data)
     if err != nil {
         return errors.Wrap(err, errCouldNotGetElement)
@@ -589,7 +589,7 @@ func (p *Page) ExtractElement(act *Action, out map[string]string) error {
} }
// WaitEvent waits for an event to happen on the page. // WaitEvent waits for an event to happen on the page.
func (p *Page) WaitEvent(act *Action, out map[string]string /*TODO review unused parameter*/) (func() error, error) { func (p *Page) WaitEvent(act *Action, out map[string]string) (func() error, error) {
event := p.getActionArgWithDefaultValues(act, "event") event := p.getActionArgWithDefaultValues(act, "event")
if event == "" { if event == "" {
return nil, errors.New("event not recognized") return nil, errors.New("event not recognized")
@ -661,14 +661,14 @@ func (p *Page) pageElementBy(data map[string]string) (*rod.Element, error) {
} }
// DebugAction enables debug action on a page. // DebugAction enables debug action on a page.
func (p *Page) DebugAction(act *Action, out map[string]string /*TODO review unused parameter*/) error { func (p *Page) DebugAction(act *Action, out map[string]string) error {
p.instance.browser.engine.SlowMotion(5 * time.Second) p.instance.browser.engine.SlowMotion(5 * time.Second)
p.instance.browser.engine.Trace(true) p.instance.browser.engine.Trace(true)
return nil return nil
} }
// SleepAction sleeps on the page for a specified duration // SleepAction sleeps on the page for a specified duration
func (p *Page) SleepAction(act *Action, out map[string]string /*TODO review unused parameter*/) error { func (p *Page) SleepAction(act *Action, out map[string]string) error {
seconds := act.Data["duration"] seconds := act.Data["duration"]
if seconds == "" { if seconds == "" {
seconds = "5" seconds = "5"


@ -4,7 +4,7 @@ import (
"context" "context"
"sync" "sync"
"github.com/remeh/sizedwaitgroup" syncutil "github.com/projectdiscovery/utils/sync"
) )
// WorkPoolType is the type of work pool to use // WorkPoolType is the type of work pool to use
@ -26,7 +26,7 @@ type StopAtFirstMatchHandler[T any] struct {
// work pool and its type // work pool and its type
poolType WorkPoolType poolType WorkPoolType
sgPool sizedwaitgroup.SizedWaitGroup sgPool *syncutil.AdaptiveWaitGroup
wgPool *sync.WaitGroup wgPool *sync.WaitGroup
// internal / unexported // internal / unexported
@ -40,10 +40,13 @@ type StopAtFirstMatchHandler[T any] struct {
// NewBlockingSPMHandler creates a new stop at first match handler // NewBlockingSPMHandler creates a new stop at first match handler
func NewBlockingSPMHandler[T any](ctx context.Context, size int, spm bool) *StopAtFirstMatchHandler[T] { func NewBlockingSPMHandler[T any](ctx context.Context, size int, spm bool) *StopAtFirstMatchHandler[T] {
ctx1, cancel := context.WithCancel(ctx) ctx1, cancel := context.WithCancel(ctx)
awg, _ := syncutil.New(syncutil.WithSize(size))
s := &StopAtFirstMatchHandler[T]{ s := &StopAtFirstMatchHandler[T]{
ResultChan: make(chan T, 1), ResultChan: make(chan T, 1),
poolType: Blocking, poolType: Blocking,
sgPool: sizedwaitgroup.New(size), sgPool: awg,
internalWg: &sync.WaitGroup{}, internalWg: &sync.WaitGroup{},
ctx: ctx1, ctx: ctx1,
cancel: cancel, cancel: cancel,
@ -140,6 +143,16 @@ func (h *StopAtFirstMatchHandler[T]) Release() {
} }
} }
func (h *StopAtFirstMatchHandler[T]) Resize(size int) {
if h.sgPool.Size != size {
h.sgPool.Resize(size)
}
}
func (h *StopAtFirstMatchHandler[T]) Size() int {
return h.sgPool.Size
}
// Wait waits for all work to be done // Wait waits for all work to be done
func (h *StopAtFirstMatchHandler[T]) Wait() { func (h *StopAtFirstMatchHandler[T]) Wait() {
switch h.poolType { switch h.poolType {
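The Resize and Size accessors added above are what later hunks call from their payload loops to retune a running pool. A small sketch of the same guard, written against a hypothetical workPool wrapper rather than the handler itself, using only the AdaptiveWaitGroup calls visible in this commit:

```go
package spmexample

import (
	syncutil "github.com/projectdiscovery/utils/sync"
)

// workPool is a hypothetical stand-in for the handler's blocking pool,
// exposing the Resize/Size surface added in this commit.
type workPool struct {
	sg *syncutil.AdaptiveWaitGroup
}

func newWorkPool(size int) (*workPool, error) {
	awg, err := syncutil.New(syncutil.WithSize(size))
	if err != nil {
		return nil, err
	}
	return &workPool{sg: awg}, nil
}

// Resize is a no-op when the requested size already matches,
// mirroring the guard in StopAtFirstMatchHandler.Resize above.
func (w *workPool) Resize(size int) {
	if w.sg.Size != size {
		w.sg.Resize(size)
	}
}

// Size reports the current capacity; AdaptiveWaitGroup exposes it as a field.
func (w *workPool) Size() int {
	return w.sg.Size
}
```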


@ -165,6 +165,9 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
// Workers that keeps enqueuing new requests // Workers that keeps enqueuing new requests
maxWorkers := request.Threads maxWorkers := request.Threads
// if request threads matches global payload concurrency we follow it
shouldFollowGlobal := maxWorkers == request.options.Options.PayloadConcurrency
if protocolstate.IsLowOnMemory() { if protocolstate.IsLowOnMemory() {
maxWorkers = protocolstate.GuardThreadsOrDefault(request.Threads) maxWorkers = protocolstate.GuardThreadsOrDefault(request.Threads)
} }
@ -198,6 +201,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency {
spmHandler.Resize(request.options.Options.PayloadConcurrency)
}
ctx := request.newContext(input) ctx := request.newContext(input)
generatedHttpRequest, err := generator.Make(ctx, input, inputData, payloads, dynamicValues) generatedHttpRequest, err := generator.Make(ctx, input, inputData, payloads, dynamicValues)
if err != nil { if err != nil {
@ -222,7 +231,7 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
return return
case spmHandler.ResultChan <- func() error { case spmHandler.ResultChan <- func() error {
// putting ratelimiter here prevents any unnecessary waiting if any // putting ratelimiter here prevents any unnecessary waiting if any
request.options.RateLimiter.Take() request.options.RateLimitTake()
previous := make(map[string]interface{}) previous := make(map[string]interface{})
return request.executeRequest(input, httpRequest, previous, false, wrappedCallback, 0) return request.executeRequest(input, httpRequest, previous, false, wrappedCallback, 0)
}(): }():
@ -366,7 +375,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
executeFunc := func(data string, payloads, dynamicValue map[string]interface{}) (bool, error) { executeFunc := func(data string, payloads, dynamicValue map[string]interface{}) (bool, error) {
hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators) hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators)
request.options.RateLimiter.Take() request.options.RateLimitTake()
ctx := request.newContext(input) ctx := request.newContext(input)
ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(request.options.Options.Timeout)*time.Second) ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(request.options.Options.Timeout)*time.Second)
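shouldFollowGlobal together with the "resize check point" is the heart of the speed-control refactor: a template whose thread count started out equal to the global payload-concurrency option keeps tracking that option if it changes mid-scan, while an explicit per-template value is left alone. A condensed sketch of that loop shape, where nextPayload, execute and globalConcurrency are illustrative stand-ins for the generator iterator, request execution and Options.PayloadConcurrency, and Wait is assumed as in the earlier sketch:

```go
package speedcontrol

import (
	syncutil "github.com/projectdiscovery/utils/sync"
)

// runPayloads mirrors the loop added to executeParallelHTTP: the pool is only
// resized when this template's thread count was inherited from the global
// payload-concurrency option and that option has since changed.
func runPayloads(templateThreads int, globalConcurrency func() int,
	nextPayload func() (string, bool), execute func(string)) error {

	shouldFollowGlobal := templateThreads == globalConcurrency()

	wg, err := syncutil.New(syncutil.WithSize(templateThreads))
	if err != nil {
		return err
	}

	for {
		payload, ok := nextPayload()
		if !ok {
			break
		}
		// resize check point - nop if there are no changes
		if shouldFollowGlobal && wg.Size != globalConcurrency() {
			wg.Resize(globalConcurrency())
		}
		wg.Add()
		go func(p string) {
			defer wg.Done()
			execute(p)
		}(payload)
	}
	wg.Wait()
	return nil
}
```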


@ -145,7 +145,7 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest,
if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(input.MetaInput.Input) { if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(input.MetaInput.Input) {
return false return false
} }
request.options.RateLimiter.Take() request.options.RateLimitTake()
req := &generatedRequest{ req := &generatedRequest{
request: gr.Request, request: gr.Request,
dynamicValues: gr.DynamicValues, dynamicValues: gr.DynamicValues,


@ -34,8 +34,8 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/types"
errorutil "github.com/projectdiscovery/utils/errors" errorutil "github.com/projectdiscovery/utils/errors"
iputil "github.com/projectdiscovery/utils/ip" iputil "github.com/projectdiscovery/utils/ip"
syncutil "github.com/projectdiscovery/utils/sync"
urlutil "github.com/projectdiscovery/utils/url" urlutil "github.com/projectdiscovery/utils/url"
"github.com/remeh/sizedwaitgroup"
) )
// Request is a request for the javascript protocol // Request is a request for the javascript protocol
@ -406,7 +406,11 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
requestOptions := request.options requestOptions := request.options
gotmatches := &atomic.Bool{} gotmatches := &atomic.Bool{}
sg := sizedwaitgroup.New(threads) // if request threads matches global payload concurrency we follow it
shouldFollowGlobal := threads == request.options.Options.PayloadConcurrency
sg, _ := syncutil.New(syncutil.WithSize(threads))
if request.generator != nil { if request.generator != nil {
iterator := request.generator.NewIterator() iterator := request.generator.NewIterator()
for { for {
@ -414,6 +418,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency {
sg.Resize(request.options.Options.PayloadConcurrency)
}
sg.Add() sg.Add()
go func() { go func() {
defer sg.Done() defer sg.Done()


@ -11,7 +11,7 @@ var (
) )
// Init initializes the clientpool implementation // Init initializes the clientpool implementation
func Init(options *types.Options /*TODO review unused parameter*/) error { func Init(options *types.Options) error {
// Don't create clients if already created in the past. // Don't create clients if already created in the past.
if normalClient != nil { if normalClient != nil {
return nil return nil


@ -12,7 +12,6 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/remeh/sizedwaitgroup"
"go.uber.org/multierr" "go.uber.org/multierr"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@ -34,6 +33,7 @@ import (
errorutil "github.com/projectdiscovery/utils/errors" errorutil "github.com/projectdiscovery/utils/errors"
mapsutil "github.com/projectdiscovery/utils/maps" mapsutil "github.com/projectdiscovery/utils/maps"
"github.com/projectdiscovery/utils/reader" "github.com/projectdiscovery/utils/reader"
syncutil "github.com/projectdiscovery/utils/sync"
) )
var ( var (
@ -174,17 +174,29 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
return err return err
} }
// if request threads matches global payload concurrency we follow it
shouldFollowGlobal := request.Threads == request.options.Options.PayloadConcurrency
if request.generator != nil { if request.generator != nil {
iterator := request.generator.NewIterator() iterator := request.generator.NewIterator()
var multiErr error var multiErr error
m := &sync.Mutex{} m := &sync.Mutex{}
swg := sizedwaitgroup.New(request.Threads) swg, err := syncutil.New(syncutil.WithSize(request.Threads))
if err != nil {
return err
}
for { for {
value, ok := iterator.Value() value, ok := iterator.Value()
if !ok { if !ok {
break break
} }
// resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency)
}
value = generators.MergeMaps(value, payloads) value = generators.MergeMaps(value, payloads)
swg.Add() swg.Add()
go func(vars map[string]interface{}) { go func(vars map[string]interface{}) {


@ -6,7 +6,6 @@ import (
"os" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/remeh/sizedwaitgroup"
"github.com/projectdiscovery/gologger" "github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/output" "github.com/projectdiscovery/nuclei/v3/pkg/output"
@ -17,6 +16,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
"github.com/projectdiscovery/utils/conversion" "github.com/projectdiscovery/utils/conversion"
syncutil "github.com/projectdiscovery/utils/sync"
) )
var _ protocols.Request = &Request{} var _ protocols.Request = &Request{}
@ -29,10 +29,13 @@ func (request *Request) Type() templateTypes.ProtocolType {
} }
// ExecuteWithResults executes the protocol requests and returns results instead of writing them. // ExecuteWithResults executes the protocol requests and returns results instead of writing them.
func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata /*TODO review unused parameter*/, previous output.InternalEvent, callback protocols.OutputEventCallback) error { func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, previous output.InternalEvent, callback protocols.OutputEventCallback) error {
wg := sizedwaitgroup.New(request.options.Options.BulkSize) wg, err := syncutil.New(syncutil.WithSize(request.options.Options.BulkSize))
if err != nil {
return err
}
err := request.getInputPaths(input.MetaInput.Input, func(data string) { err = request.getInputPaths(input.MetaInput.Input, func(data string) {
wg.Add() wg.Add()
go func(data string) { go func(data string) {


@ -34,9 +34,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/types"
) )
// Optional Callback to update Thread count in payloads across all requests
type PayloadThreadSetterCallback func(opts *ExecutorOptions, totalRequests, currentThreads int) int
var ( var (
MaxTemplateFileSizeForEncoding = 1024 * 1024 MaxTemplateFileSizeForEncoding = 1024 * 1024
) )
@ -114,10 +111,6 @@ type ExecutorOptions struct {
// JsCompiler is abstracted javascript compiler which adds node modules and provides execution // JsCompiler is abstracted javascript compiler which adds node modules and provides execution
// environment for javascript templates // environment for javascript templates
JsCompiler *compiler.Compiler JsCompiler *compiler.Compiler
// Optional Callback function to update Thread count in payloads across all protocols
// based on given logic. by default nuclei reverts to using value of `-c` when threads count
// is not specified or is 0 in template
OverrideThreadsCount PayloadThreadSetterCallback
// AuthProvider is a provider for auth strategies // AuthProvider is a provider for auth strategies
AuthProvider authprovider.AuthProvider AuthProvider authprovider.AuthProvider
//TemporaryDirectory is the directory to store temporary files //TemporaryDirectory is the directory to store temporary files
@ -128,17 +121,25 @@ type ExecutorOptions struct {
ExportReqURLPattern bool ExportReqURLPattern bool
} }
// todo: centralizing components is not feasible with current clogged architecture
// a possible approach could be an internal event bus with pub-subs? This would be less invasive than
// reworking dep injection from scratch
func (eo *ExecutorOptions) RateLimitTake() {
if eo.RateLimiter.GetLimit() != uint(eo.Options.RateLimit) {
eo.RateLimiter.SetLimit(uint(eo.Options.RateLimit))
eo.RateLimiter.SetDuration(eo.Options.RateLimitDuration)
}
eo.RateLimiter.Take()
}
// GetThreadsForPayloadRequests returns the number of threads to use as default for // GetThreadsForPayloadRequests returns the number of threads to use as default for
// given max-request of payloads // given max-request of payloads
func (e *ExecutorOptions) GetThreadsForNPayloadRequests(totalRequests int, currentThreads int) int { func (e *ExecutorOptions) GetThreadsForNPayloadRequests(totalRequests int, currentThreads int) int {
if e.OverrideThreadsCount != nil {
return e.OverrideThreadsCount(e, totalRequests, currentThreads)
}
if currentThreads > 0 { if currentThreads > 0 {
return currentThreads return currentThreads
} else {
return e.Options.PayloadConcurrency
} }
return e.Options.PayloadConcurrency
} }
// CreateTemplateCtxStore creates template context store (which contains templateCtx for every scan) // CreateTemplateCtxStore creates template context store (which contains templateCtx for every scan)
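RateLimitTake is the rate-limiting half of the refactor: call sites that previously did RateLimiter.Take() now re-sync the shared limiter with the current Options.RateLimit and Options.RateLimitDuration before taking a token, so changed values apply mid-scan. A sketch of that sync-then-take pattern against a minimal interface; the interface itself is hypothetical, while the method names match the calls in the hunk above.

```go
package speedcontrol

import "time"

// tokenLimiter is a hypothetical interface covering only the calls made by
// ExecutorOptions.RateLimitTake; nuclei's concrete limiter lives elsewhere.
type tokenLimiter interface {
	GetLimit() uint
	SetLimit(max uint)
	SetDuration(d time.Duration)
	Take()
}

// rateLimitTake re-applies the latest limit and duration before blocking,
// mirroring the method added in this commit.
func rateLimitTake(rl tokenLimiter, rateLimit int, duration time.Duration) {
	if rl.GetLimit() != uint(rateLimit) {
		rl.SetLimit(uint(rateLimit))
		rl.SetDuration(duration)
	}
	rl.Take()
}
```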


@ -54,6 +54,7 @@ var DefaultOptions = &types.Options{
Timeout: 5, Timeout: 5,
Retries: 1, Retries: 1,
RateLimit: 150, RateLimit: 150,
RateLimitDuration: time.Second,
ProjectPath: "", ProjectPath: "",
Severities: severity.Severities{}, Severities: severity.Severities{},
Targets: []string{}, Targets: []string{},


@ -1,6 +1,7 @@
package flow package flow
import ( import (
"context"
"reflect" "reflect"
"sync" "sync"
@ -12,33 +13,10 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/vardump" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/vardump"
"github.com/projectdiscovery/nuclei/v3/pkg/tmplexec/flow/builtin" "github.com/projectdiscovery/nuclei/v3/pkg/tmplexec/flow/builtin"
"github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/remeh/sizedwaitgroup" "github.com/projectdiscovery/utils/sync/sizedpool"
) )
type jsWaitGroup struct { var jsOnce sync.Once
sync.Once
sg sizedwaitgroup.SizedWaitGroup
}
var jsPool = &jsWaitGroup{}
// GetJSRuntime returns a new JS runtime from pool
func GetJSRuntime(opts *types.Options) *goja.Runtime {
jsPool.Do(func() {
if opts.JsConcurrency < 100 {
opts.JsConcurrency = 100
}
jsPool.sg = sizedwaitgroup.New(opts.JsConcurrency)
})
jsPool.sg.Add()
return gojapool.Get().(*goja.Runtime)
}
// PutJSRuntime returns a JS runtime to pool
func PutJSRuntime(runtime *goja.Runtime) {
defer jsPool.sg.Done()
gojapool.Put(runtime)
}
// js runtime pool using sync.Pool // js runtime pool using sync.Pool
var gojapool = &sync.Pool{ var gojapool = &sync.Pool{
@ -49,8 +27,29 @@ var gojapool = &sync.Pool{
}, },
} }
func registerBuiltins(runtime *goja.Runtime) { var sizedgojapool *sizedpool.SizedPool[*goja.Runtime]
// GetJSRuntime returns a new JS runtime from pool
func GetJSRuntime(opts *types.Options) *goja.Runtime {
jsOnce.Do(func() {
if opts.JsConcurrency < 100 {
opts.JsConcurrency = 100
}
sizedgojapool, _ = sizedpool.New[*goja.Runtime](
sizedpool.WithPool[*goja.Runtime](gojapool),
sizedpool.WithSize[*goja.Runtime](int64(opts.JsConcurrency)),
)
})
runtime, _ := sizedgojapool.Get(context.TODO())
return runtime
}
// PutJSRuntime returns a JS runtime to pool
func PutJSRuntime(runtime *goja.Runtime) {
sizedgojapool.Put(runtime)
}
func registerBuiltins(runtime *goja.Runtime) {
_ = gojs.RegisterFuncWithSignature(runtime, gojs.FuncOpts{ _ = gojs.RegisterFuncWithSignature(runtime, gojs.FuncOpts{
Name: "log", Name: "log",
Description: "Logs a given object/message to stdout (only for debugging purposes)", Description: "Logs a given object/message to stdout (only for debugging purposes)",


@ -132,7 +132,10 @@ type Options struct {
Retries int Retries int
// Rate-Limit is the maximum number of requests per specified target // Rate-Limit is the maximum number of requests per specified target
RateLimit int RateLimit int
// Rate Limit Duration interval between burst resets
RateLimitDuration time.Duration
// Rate-Limit is the maximum number of requests per minute for specified target // Rate-Limit is the maximum number of requests per minute for specified target
// Deprecated: Use RateLimitDuration - automatically set Rate Limit Duration to 60 seconds
RateLimitMinute int RateLimitMinute int
// PageTimeout is the maximum time to wait for a page in seconds // PageTimeout is the maximum time to wait for a page in seconds
PageTimeout int PageTimeout int
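With RateLimitDuration on Options, the deprecated RateLimitMinute reduces to a special case: N requests per minute is RateLimit = N over a one-minute window, which the deprecation note says is set automatically. A hedged sketch of what that normalization amounts to; where nuclei actually applies it is outside this diff, and normalizeRateLimit is a hypothetical helper using only the fields declared above.

```go
package speedcontrol

import (
	"time"

	"github.com/projectdiscovery/nuclei/v3/pkg/types"
)

// normalizeRateLimit illustrates the mapping described by the deprecation
// note: a per-minute limit is a per-duration limit with a one-minute window.
func normalizeRateLimit(opts *types.Options) {
	if opts.RateLimitMinute > 0 {
		opts.RateLimit = opts.RateLimitMinute
		opts.RateLimitDuration = time.Minute
	}
	if opts.RateLimitDuration == 0 {
		opts.RateLimitDuration = time.Second // matches the new default set in this commit
	}
}
```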
@ -382,6 +385,8 @@ type Options struct {
SkipFormatValidation bool SkipFormatValidation bool
// PayloadConcurrency is the number of concurrent payloads to run per template // PayloadConcurrency is the number of concurrent payloads to run per template
PayloadConcurrency int PayloadConcurrency int
// ProbeConcurrency is the number of concurrent http probes to run with httpx
ProbeConcurrency int
// Dast only runs DAST templates // Dast only runs DAST templates
DAST bool DAST bool
} }
@ -410,6 +415,7 @@ func (options *Options) HasClientCertificates() bool {
func DefaultOptions() *Options { func DefaultOptions() *Options {
return &Options{ return &Options{
RateLimit: 150, RateLimit: 150,
RateLimitDuration: time.Second,
BulkSize: 25, BulkSize: 25,
TemplateThreads: 25, TemplateThreads: 25,
HeadlessBulkSize: 10, HeadlessBulkSize: 10,