Added dns client pool + misc changes to http client pool

dev
Ice3man543 2020-12-23 22:09:11 +05:30
parent c4428824b6
commit ff4c61a0eb
2 changed files with 110 additions and 4 deletions

View File

@ -0,0 +1,78 @@
package clientpool
import (
"strconv"
"strings"
"sync"
"github.com/projectdiscovery/nuclei/v2/pkg/types"
"github.com/projectdiscovery/retryabledns"
)
var (
poolMutex *sync.RWMutex
normalClient *retryabledns.Client
clientPool map[string]*retryabledns.Client
)
// defaultResolvers contains the list of resolvers known to be trusted.
var defaultResolvers = []string{
"1.1.1.1:53", // Cloudflare
"1.0.0.1:53", // Cloudflare
"8.8.8.8:53", // Google
"8.8.4.4:53", // Google
}
// Init initializes the clientpool implementation
func Init(options *types.Options) error {
// Don't create clients if already created in past.
if normalClient != nil {
return nil
}
poolMutex = &sync.RWMutex{}
clientPool = make(map[string]*retryabledns.Client)
if client, err := Get(options, &Configuration{}); err != nil {
return err
} else {
normalClient = client
}
return nil
}
// Configuration contains the custom configuration options for a client
type Configuration struct {
// Retries contains the retries for the dns client
Retries int
}
// Hash returns the hash of the configuration to allow client pooling
func (c *Configuration) Hash() string {
builder := &strings.Builder{}
builder.Grow(8)
builder.WriteString("r")
builder.WriteString(strconv.Itoa(c.Retries))
hash := builder.String()
return hash
}
// Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*retryabledns.Client, error) {
if !(configuration.Retries > 0) {
return normalClient, nil
}
hash := configuration.Hash()
poolMutex.RLock()
if client, ok := clientPool[hash]; ok {
poolMutex.RUnlock()
return client, nil
}
poolMutex.RUnlock()
client := retryabledns.New(defaultResolvers, configuration.Retries)
poolMutex.Lock()
clientPool[hash] = client
poolMutex.Unlock()
return client, nil
}

View File

@ -15,19 +15,34 @@ import (
"github.com/pkg/errors"
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v2/pkg/types"
"github.com/projectdiscovery/rawhttp"
"github.com/projectdiscovery/retryablehttp-go"
"golang.org/x/net/proxy"
)
var (
dialer *fastdialer.Dialer
poolMutex *sync.RWMutex
clientPool map[string]*retryablehttp.Client
dialer *fastdialer.Dialer
rawhttpClient *rawhttp.Client
poolMutex *sync.RWMutex
normalClient *retryablehttp.Client
clientPool map[string]*retryablehttp.Client
)
func init() {
// Init initializes the clientpool implementation
func Init(options *types.Options) error {
// Don't create clients if already created in past.
if normalClient != nil {
return nil
}
poolMutex = &sync.RWMutex{}
clientPool = make(map[string]*retryablehttp.Client)
if client, err := Get(options, &Configuration{}); err != nil {
return err
} else {
normalClient = client
}
return nil
}
// Configuration contains the custom configuration options for a client
@ -43,6 +58,7 @@ type Configuration struct {
// Hash returns the hash of the configuration to allow client pooling
func (c *Configuration) Hash() string {
builder := &strings.Builder{}
builder.Grow(16)
builder.WriteString("t")
builder.WriteString(strconv.Itoa(c.Threads))
builder.WriteString("m")
@ -53,8 +69,19 @@ func (c *Configuration) Hash() string {
return hash
}
// GetRawHTTP returns the rawhttp request client
func GetRawHTTP() *rawhttp.Client {
if rawhttpClient == nil {
rawhttpClient = rawhttp.NewClient(rawhttp.DefaultOptions)
}
return rawhttpClient
}
// Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
if !(configuration.Threads > 0 && configuration.MaxRedirects > 0 && configuration.FollowRedirects) {
return normalClient, nil
}
var proxyURL *url.URL
var err error
@ -139,6 +166,7 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C
Timeout: time.Duration(options.Timeout) * time.Second,
CheckRedirect: makeCheckRedirectFunc(followRedirects, maxRedirects),
}, retryablehttpOptions)
client.CheckRetry = retryablehttp.HostSprayRetryPolicy()
poolMutex.Lock()
clientPool[hash] = client