mirror of https://github.com/daffainfo/nuclei.git
Adding custom cancel function
parent
09ceb29ba3
commit
70cecf83fb
|
@ -45,6 +45,7 @@ type generatedRequest struct {
|
||||||
request *retryablehttp.Request
|
request *retryablehttp.Request
|
||||||
dynamicValues map[string]interface{}
|
dynamicValues map[string]interface{}
|
||||||
interactshURLs []string
|
interactshURLs []string
|
||||||
|
customCancelFunction context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *generatedRequest) URL() string {
|
func (g *generatedRequest) URL() string {
|
||||||
|
@ -297,11 +298,20 @@ func (r *requestGenerator) handleRawWithPayloads(ctx context.Context, rawRequest
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if reqWithAnnotations, hasAnnotations := r.request.parseAnnotations(rawRequest, req); hasAnnotations {
|
generatedRequest := &generatedRequest{
|
||||||
request.Request = reqWithAnnotations
|
request: request,
|
||||||
|
meta: generatorValues,
|
||||||
|
original: r.request,
|
||||||
|
dynamicValues: finalValues,
|
||||||
|
interactshURLs: r.interactshURLs,
|
||||||
}
|
}
|
||||||
|
|
||||||
return &generatedRequest{request: request, meta: generatorValues, original: r.request, dynamicValues: finalValues, interactshURLs: r.interactshURLs}, nil
|
if reqWithAnnotations, cancelFunc, hasAnnotations := r.request.parseAnnotations(rawRequest, req); hasAnnotations {
|
||||||
|
generatedRequest.request.Request = reqWithAnnotations
|
||||||
|
generatedRequest.customCancelFunction = cancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
return generatedRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fillRequest fills various headers in the request with values
|
// fillRequest fills various headers in the request with values
|
||||||
|
|
|
@ -267,6 +267,11 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
|
||||||
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
|
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if generatedHttpRequest.customCancelFunction != nil {
|
||||||
|
defer generatedHttpRequest.customCancelFunction()
|
||||||
|
}
|
||||||
|
|
||||||
// If the variables contain interactsh urls, use them
|
// If the variables contain interactsh urls, use them
|
||||||
if len(interactURLs) > 0 {
|
if len(interactURLs) > 0 {
|
||||||
generatedHttpRequest.interactshURLs = append(generatedHttpRequest.interactshURLs, interactURLs...)
|
generatedHttpRequest.interactshURLs = append(generatedHttpRequest.interactshURLs, interactURLs...)
|
||||||
|
|
|
@ -49,9 +49,13 @@ func parseFlowAnnotations(rawRequest string) (flowMark, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAnnotations and override requests settings
|
// parseAnnotations and override requests settings
|
||||||
func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, bool) {
|
func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, context.CancelFunc, bool) {
|
||||||
|
var (
|
||||||
|
modified bool
|
||||||
|
cancelFunc context.CancelFunc
|
||||||
|
)
|
||||||
// parse request for known ovverride annotations
|
// parse request for known ovverride annotations
|
||||||
var modified bool
|
|
||||||
// @Host:target
|
// @Host:target
|
||||||
if hosts := reHostAnnotation.FindStringSubmatch(rawRequest); len(hosts) > 0 {
|
if hosts := reHostAnnotation.FindStringSubmatch(rawRequest); len(hosts) > 0 {
|
||||||
value := strings.TrimSpace(hosts[1])
|
value := strings.TrimSpace(hosts[1])
|
||||||
|
@ -97,22 +101,23 @@ func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*h
|
||||||
// @timeout:duration
|
// @timeout:duration
|
||||||
if r.connConfiguration.NoTimeout {
|
if r.connConfiguration.NoTimeout {
|
||||||
modified = true
|
modified = true
|
||||||
|
var ctx context.Context
|
||||||
|
|
||||||
if duration := reTimeoutAnnotation.FindStringSubmatch(rawRequest); len(duration) > 0 {
|
if duration := reTimeoutAnnotation.FindStringSubmatch(rawRequest); len(duration) > 0 {
|
||||||
value := strings.TrimSpace(duration[1])
|
value := strings.TrimSpace(duration[1])
|
||||||
if parsed, err := time.ParseDuration(value); err == nil {
|
if parsed, err := time.ParseDuration(value); err == nil {
|
||||||
//nolint:govet // cancelled automatically by withTimeout
|
//nolint:govet // cancelled automatically by withTimeout
|
||||||
ctx, _ := context.WithTimeout(context.Background(), parsed)
|
ctx, cancelFunc = context.WithTimeout(context.Background(), parsed)
|
||||||
request = request.Clone(ctx)
|
request = request.Clone(ctx)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//nolint:govet // cancelled automatically by withTimeout
|
//nolint:govet // cancelled automatically by withTimeout
|
||||||
ctx, _ := context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
|
ctx, cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
|
||||||
request = request.Clone(ctx)
|
request = request.Clone(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return request, modified
|
return request, cancelFunc, modified
|
||||||
}
|
}
|
||||||
|
|
||||||
func isHostPort(value string) bool {
|
func isHostPort(value string) bool {
|
||||||
|
|
|
@ -21,7 +21,8 @@ func TestRequestParseAnnotationsTimeout(t *testing.T) {
|
||||||
httpReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
httpReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
require.Nil(t, err, "could not create http request")
|
require.Nil(t, err, "could not create http request")
|
||||||
|
|
||||||
newRequest, modified := request.parseAnnotations(rawRequest, httpReq)
|
newRequest, cancelFunc, modified := request.parseAnnotations(rawRequest, httpReq)
|
||||||
|
require.NotNil(t, cancelFunc, "could not initialize valid cancel function")
|
||||||
require.True(t, modified, "could not get correct modified value")
|
require.True(t, modified, "could not get correct modified value")
|
||||||
_, deadlined := newRequest.Context().Deadline()
|
_, deadlined := newRequest.Context().Deadline()
|
||||||
require.True(t, deadlined, "could not get set request deadline")
|
require.True(t, deadlined, "could not get set request deadline")
|
||||||
|
@ -37,7 +38,8 @@ func TestRequestParseAnnotationsTimeout(t *testing.T) {
|
||||||
httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com", nil)
|
httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com", nil)
|
||||||
require.Nil(t, err, "could not create http request")
|
require.Nil(t, err, "could not create http request")
|
||||||
|
|
||||||
newRequest, modified := request.parseAnnotations(rawRequest, httpReq)
|
newRequest, cancelFunc, modified := request.parseAnnotations(rawRequest, httpReq)
|
||||||
|
require.Nil(t, cancelFunc, "cancel function should be nil")
|
||||||
require.False(t, modified, "could not get correct modified value")
|
require.False(t, modified, "could not get correct modified value")
|
||||||
_, deadlined := newRequest.Context().Deadline()
|
_, deadlined := newRequest.Context().Deadline()
|
||||||
require.False(t, deadlined, "could not get set request deadline")
|
require.False(t, deadlined, "could not get set request deadline")
|
||||||
|
|
Loading…
Reference in New Issue