Adding custom cancel function

dev
mzack 2022-10-10 08:10:07 +02:00
parent 09ceb29ba3
commit 70cecf83fb
4 changed files with 39 additions and 17 deletions

View File

@ -45,6 +45,7 @@ type generatedRequest struct {
request *retryablehttp.Request
dynamicValues map[string]interface{}
interactshURLs []string
customCancelFunction context.CancelFunc
}
func (g *generatedRequest) URL() string {
@ -297,11 +298,20 @@ func (r *requestGenerator) handleRawWithPayloads(ctx context.Context, rawRequest
return nil, err
}
if reqWithAnnotations, hasAnnotations := r.request.parseAnnotations(rawRequest, req); hasAnnotations {
request.Request = reqWithAnnotations
generatedRequest := &generatedRequest{
request: request,
meta: generatorValues,
original: r.request,
dynamicValues: finalValues,
interactshURLs: r.interactshURLs,
}
return &generatedRequest{request: request, meta: generatorValues, original: r.request, dynamicValues: finalValues, interactshURLs: r.interactshURLs}, nil
if reqWithAnnotations, cancelFunc, hasAnnotations := r.request.parseAnnotations(rawRequest, req); hasAnnotations {
generatedRequest.request.Request = reqWithAnnotations
generatedRequest.customCancelFunction = cancelFunc
}
return generatedRequest, nil
}
// fillRequest fills various headers in the request with values

View File

@ -267,6 +267,11 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return true, err
}
if generatedHttpRequest.customCancelFunction != nil {
defer generatedHttpRequest.customCancelFunction()
}
// If the variables contain interactsh urls, use them
if len(interactURLs) > 0 {
generatedHttpRequest.interactshURLs = append(generatedHttpRequest.interactshURLs, interactURLs...)

View File

@ -49,9 +49,13 @@ func parseFlowAnnotations(rawRequest string) (flowMark, bool) {
}
// parseAnnotations and override requests settings
func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, bool) {
func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, context.CancelFunc, bool) {
var (
modified bool
cancelFunc context.CancelFunc
)
// parse request for known ovverride annotations
var modified bool
// @Host:target
if hosts := reHostAnnotation.FindStringSubmatch(rawRequest); len(hosts) > 0 {
value := strings.TrimSpace(hosts[1])
@ -97,22 +101,23 @@ func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*h
// @timeout:duration
if r.connConfiguration.NoTimeout {
modified = true
var ctx context.Context
if duration := reTimeoutAnnotation.FindStringSubmatch(rawRequest); len(duration) > 0 {
value := strings.TrimSpace(duration[1])
if parsed, err := time.ParseDuration(value); err == nil {
//nolint:govet // cancelled automatically by withTimeout
ctx, _ := context.WithTimeout(context.Background(), parsed)
ctx, cancelFunc = context.WithTimeout(context.Background(), parsed)
request = request.Clone(ctx)
}
} else {
//nolint:govet // cancelled automatically by withTimeout
ctx, _ := context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
ctx, cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
request = request.Clone(ctx)
}
}
return request, modified
return request, cancelFunc, modified
}
func isHostPort(value string) bool {

View File

@ -21,7 +21,8 @@ func TestRequestParseAnnotationsTimeout(t *testing.T) {
httpReq, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
require.Nil(t, err, "could not create http request")
newRequest, modified := request.parseAnnotations(rawRequest, httpReq)
newRequest, cancelFunc, modified := request.parseAnnotations(rawRequest, httpReq)
require.NotNil(t, cancelFunc, "could not initialize valid cancel function")
require.True(t, modified, "could not get correct modified value")
_, deadlined := newRequest.Context().Deadline()
require.True(t, deadlined, "could not get set request deadline")
@ -37,7 +38,8 @@ func TestRequestParseAnnotationsTimeout(t *testing.T) {
httpReq, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://example.com", nil)
require.Nil(t, err, "could not create http request")
newRequest, modified := request.parseAnnotations(rawRequest, httpReq)
newRequest, cancelFunc, modified := request.parseAnnotations(rawRequest, httpReq)
require.Nil(t, cancelFunc, "cancel function should be nil")
require.False(t, modified, "could not get correct modified value")
_, deadlined := newRequest.Context().Deadline()
require.False(t, deadlined, "could not get set request deadline")