Adding custom cancel function

dev
mzack 2022-10-10 08:10:07 +02:00
parent 09ceb29ba3
commit 70cecf83fb
4 changed files with 39 additions and 17 deletions

View File

@ -45,6 +45,7 @@ type generatedRequest struct {
request *retryablehttp.Request request *retryablehttp.Request
dynamicValues map[string]interface{} dynamicValues map[string]interface{}
interactshURLs []string interactshURLs []string
customCancelFunction context.CancelFunc
} }
func (g *generatedRequest) URL() string { func (g *generatedRequest) URL() string {
@ -297,11 +298,20 @@ func (r *requestGenerator) handleRawWithPayloads(ctx context.Context, rawRequest
return nil, err return nil, err
} }
if reqWithAnnotations, hasAnnotations := r.request.parseAnnotations(rawRequest, req); hasAnnotations { generatedRequest := &generatedRequest{
request.Request = reqWithAnnotations request: request,
meta: generatorValues,
original: r.request,
dynamicValues: finalValues,
interactshURLs: r.interactshURLs,
} }
return &generatedRequest{request: request, meta: generatorValues, original: r.request, dynamicValues: finalValues, interactshURLs: r.interactshURLs}, nil if reqWithAnnotations, cancelFunc, hasAnnotations := r.request.parseAnnotations(rawRequest, req); hasAnnotations {
generatedRequest.request.Request = reqWithAnnotations
generatedRequest.customCancelFunction = cancelFunc
}
return generatedRequest, nil
} }
// fillRequest fills various headers in the request with values // fillRequest fills various headers in the request with values

View File

@ -267,6 +267,11 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total())) request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return true, err return true, err
} }
if generatedHttpRequest.customCancelFunction != nil {
defer generatedHttpRequest.customCancelFunction()
}
// If the variables contain interactsh urls, use them // If the variables contain interactsh urls, use them
if len(interactURLs) > 0 { if len(interactURLs) > 0 {
generatedHttpRequest.interactshURLs = append(generatedHttpRequest.interactshURLs, interactURLs...) generatedHttpRequest.interactshURLs = append(generatedHttpRequest.interactshURLs, interactURLs...)

View File

@ -49,9 +49,13 @@ func parseFlowAnnotations(rawRequest string) (flowMark, bool) {
} }
// parseAnnotations and override requests settings // parseAnnotations and override requests settings
func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, bool) { func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, context.CancelFunc, bool) {
var (
modified bool
cancelFunc context.CancelFunc
)
// parse request for known ovverride annotations // parse request for known ovverride annotations
var modified bool
// @Host:target // @Host:target
if hosts := reHostAnnotation.FindStringSubmatch(rawRequest); len(hosts) > 0 { if hosts := reHostAnnotation.FindStringSubmatch(rawRequest); len(hosts) > 0 {
value := strings.TrimSpace(hosts[1]) value := strings.TrimSpace(hosts[1])
@ -97,22 +101,23 @@ func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*h
// @timeout:duration // @timeout:duration
if r.connConfiguration.NoTimeout { if r.connConfiguration.NoTimeout {
modified = true modified = true
var ctx context.Context
if duration := reTimeoutAnnotation.FindStringSubmatch(rawRequest); len(duration) > 0 { if duration := reTimeoutAnnotation.FindStringSubmatch(rawRequest); len(duration) > 0 {
value := strings.TrimSpace(duration[1]) value := strings.TrimSpace(duration[1])
if parsed, err := time.ParseDuration(value); err == nil { if parsed, err := time.ParseDuration(value); err == nil {
//nolint:govet // cancelled automatically by withTimeout //nolint:govet // cancelled automatically by withTimeout
ctx, _ := context.WithTimeout(context.Background(), parsed) ctx, cancelFunc = context.WithTimeout(context.Background(), parsed)
request = request.Clone(ctx) request = request.Clone(ctx)
} }
} else { } else {
//nolint:govet // cancelled automatically by withTimeout //nolint:govet // cancelled automatically by withTimeout
ctx, _ := context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second) ctx, cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
request = request.Clone(ctx) request = request.Clone(ctx)
} }
} }
return request, modified return request, cancelFunc, modified
} }
func isHostPort(value string) bool { func isHostPort(value string) bool {

View File

@ -21,7 +21,8 @@ func TestRequestParseAnnotationsTimeout(t *testing.T) {
httpReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil) httpReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
require.Nil(t, err, "could not create http request") require.Nil(t, err, "could not create http request")
newRequest, modified := request.parseAnnotations(rawRequest, httpReq) newRequest, cancelFunc, modified := request.parseAnnotations(rawRequest, httpReq)
require.NotNil(t, cancelFunc, "could not initialize valid cancel function")
require.True(t, modified, "could not get correct modified value") require.True(t, modified, "could not get correct modified value")
_, deadlined := newRequest.Context().Deadline() _, deadlined := newRequest.Context().Deadline()
require.True(t, deadlined, "could not get set request deadline") require.True(t, deadlined, "could not get set request deadline")
@ -37,7 +38,8 @@ func TestRequestParseAnnotationsTimeout(t *testing.T) {
httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com", nil) httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com", nil)
require.Nil(t, err, "could not create http request") require.Nil(t, err, "could not create http request")
newRequest, modified := request.parseAnnotations(rawRequest, httpReq) newRequest, cancelFunc, modified := request.parseAnnotations(rawRequest, httpReq)
require.Nil(t, cancelFunc, "cancel function should be nil")
require.False(t, modified, "could not get correct modified value") require.False(t, modified, "could not get correct modified value")
_, deadlined := newRequest.Context().Deadline() _, deadlined := newRequest.Context().Deadline()
require.False(t, deadlined, "could not get set request deadline") require.False(t, deadlined, "could not get set request deadline")