From ff4c61a0eb9455bb54cfb783bb9c8e22c4da66da Mon Sep 17 00:00:00 2001 From: Ice3man543 Date: Wed, 23 Dec 2020 22:09:11 +0530 Subject: [PATCH] Added dns client pool + misc changes to http client pool --- v2/pkg/protocols/dns/clientpool/clientpool.go | 78 +++++++++++++++++++ .../protocols/http/clientpool/clientpool.go | 36 ++++++++- 2 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 v2/pkg/protocols/dns/clientpool/clientpool.go diff --git a/v2/pkg/protocols/dns/clientpool/clientpool.go b/v2/pkg/protocols/dns/clientpool/clientpool.go new file mode 100644 index 00000000..442dc8e5 --- /dev/null +++ b/v2/pkg/protocols/dns/clientpool/clientpool.go @@ -0,0 +1,78 @@ +package clientpool + +import ( + "strconv" + "strings" + "sync" + + "github.com/projectdiscovery/nuclei/v2/pkg/types" + "github.com/projectdiscovery/retryabledns" +) + +var ( + poolMutex *sync.RWMutex + normalClient *retryabledns.Client + clientPool map[string]*retryabledns.Client +) + +// defaultResolvers contains the list of resolvers known to be trusted. +var defaultResolvers = []string{ + "1.1.1.1:53", // Cloudflare + "1.0.0.1:53", // Cloudflare + "8.8.8.8:53", // Google + "8.8.4.4:53", // Google +} + +// Init initializes the clientpool implementation +func Init(options *types.Options) error { + // Don't create clients if already created in past. + if normalClient != nil { + return nil + } + poolMutex = &sync.RWMutex{} + clientPool = make(map[string]*retryabledns.Client) + + if client, err := Get(options, &Configuration{}); err != nil { + return err + } else { + normalClient = client + } + return nil +} + +// Configuration contains the custom configuration options for a client +type Configuration struct { + // Retries contains the retries for the dns client + Retries int +} + +// Hash returns the hash of the configuration to allow client pooling +func (c *Configuration) Hash() string { + builder := &strings.Builder{} + builder.Grow(8) + builder.WriteString("r") + builder.WriteString(strconv.Itoa(c.Retries)) + hash := builder.String() + return hash +} + +// Get creates or gets a client for the protocol based on custom configuration +func Get(options *types.Options, configuration *Configuration) (*retryabledns.Client, error) { + if !(configuration.Retries > 0) { + return normalClient, nil + } + hash := configuration.Hash() + poolMutex.RLock() + if client, ok := clientPool[hash]; ok { + poolMutex.RUnlock() + return client, nil + } + poolMutex.RUnlock() + + client := retryabledns.New(defaultResolvers, configuration.Retries) + + poolMutex.Lock() + clientPool[hash] = client + poolMutex.Unlock() + return client, nil +} diff --git a/v2/pkg/protocols/http/clientpool/clientpool.go b/v2/pkg/protocols/http/clientpool/clientpool.go index 747b5537..90bc2a89 100644 --- a/v2/pkg/protocols/http/clientpool/clientpool.go +++ b/v2/pkg/protocols/http/clientpool/clientpool.go @@ -15,19 +15,34 @@ import ( "github.com/pkg/errors" "github.com/projectdiscovery/fastdialer/fastdialer" "github.com/projectdiscovery/nuclei/v2/pkg/types" + "github.com/projectdiscovery/rawhttp" "github.com/projectdiscovery/retryablehttp-go" "golang.org/x/net/proxy" ) var ( - dialer *fastdialer.Dialer - poolMutex *sync.RWMutex - clientPool map[string]*retryablehttp.Client + dialer *fastdialer.Dialer + rawhttpClient *rawhttp.Client + poolMutex *sync.RWMutex + normalClient *retryablehttp.Client + clientPool map[string]*retryablehttp.Client ) -func init() { +// Init initializes the clientpool implementation +func Init(options *types.Options) error { + // Don't create clients if already created in past. + if normalClient != nil { + return nil + } poolMutex = &sync.RWMutex{} clientPool = make(map[string]*retryablehttp.Client) + + if client, err := Get(options, &Configuration{}); err != nil { + return err + } else { + normalClient = client + } + return nil } // Configuration contains the custom configuration options for a client @@ -43,6 +58,7 @@ type Configuration struct { // Hash returns the hash of the configuration to allow client pooling func (c *Configuration) Hash() string { builder := &strings.Builder{} + builder.Grow(16) builder.WriteString("t") builder.WriteString(strconv.Itoa(c.Threads)) builder.WriteString("m") @@ -53,8 +69,19 @@ func (c *Configuration) Hash() string { return hash } +// GetRawHTTP returns the rawhttp request client +func GetRawHTTP() *rawhttp.Client { + if rawhttpClient == nil { + rawhttpClient = rawhttp.NewClient(rawhttp.DefaultOptions) + } + return rawhttpClient +} + // Get creates or gets a client for the protocol based on custom configuration func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { + if !(configuration.Threads > 0 && configuration.MaxRedirects > 0 && configuration.FollowRedirects) { + return normalClient, nil + } var proxyURL *url.URL var err error @@ -139,6 +166,7 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C Timeout: time.Duration(options.Timeout) * time.Second, CheckRedirect: makeCheckRedirectFunc(followRedirects, maxRedirects), }, retryablehttpOptions) + client.CheckRetry = retryablehttp.HostSprayRetryPolicy() poolMutex.Lock() clientPool[hash] = client