extending request/response hijacking with native calls (#3091)

* extending request/response hijacking with native calls

* fixing tests
dev
Mzack9999 2023-01-05 12:56:18 +01:00 committed by GitHub
parent 4aa2002e72
commit a96f764959
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 337 additions and 75 deletions

View File

@ -0,0 +1,29 @@
id: file-upload
info:
name: Basic File Upload
author: pdteam
severity: info
headless:
- steps:
- action: navigate
args:
url: "{{BaseURL}}"
- action: waitload
- action: files
args:
by: xpath
xpath: /html/body/form/input[1]
value: headless/file-upload.yaml
- action: sleep
args:
duration: 2
- action: click
args:
by: x
xpath: /html/body/form/input[2]
matchers:
- type: word
words:
- "Basic File Upload"

View File

@ -1,6 +1,7 @@
package main
import (
"io"
"net/http"
"net/http/httptest"
@ -15,6 +16,7 @@ var headlessTestcases = map[string]testutils.TestCase{
"headless/headless-extract-values.yaml": &headlessExtractValues{},
"headless/headless-payloads.yaml": &headlessPayloads{},
"headless/variables.yaml": &headlessVariables{},
"headless/file-upload.yaml": &headlessFileUpload{},
}
type headlessBasic struct{}
@ -111,3 +113,48 @@ func (h *headlessVariables) Execute(filePath string) error {
return expectResultsCount(results, 1)
}
type headlessFileUpload struct{}
// Execute executes a test case and returns an error if occurred
func (h *headlessFileUpload) Execute(filePath string) error {
router := httprouter.New()
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
_, _ = w.Write([]byte(`
<!doctype html>
<body>
<form method=post enctype=multipart/form-data>
<input type=file name=file>
<input type=submit value=Upload>
</form>
</body>
</html>
`))
})
router.POST("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
file, _, err := r.FormFile("file")
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer file.Close()
content, err := io.ReadAll(file)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_, _ = w.Write(content)
})
ts := httptest.NewServer(router)
defer ts.Close()
results, err := testutils.RunNucleiTemplateAndGetResults(filePath, ts.URL, debug, "-headless")
if err != nil {
return err
}
return expectResultsCount(results, 1)
}

View File

@ -0,0 +1,99 @@
package engine
// TODO: redundant from katana - the whole headless package should be replace with katana
import (
"encoding/base64"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/proto"
)
// NewHijack create hijack from page.
func NewHijack(page *rod.Page) *Hijack {
return &Hijack{
page: page,
disable: &proto.FetchDisable{},
}
}
// HijackHandler type
type HijackHandler = func(e *proto.FetchRequestPaused) error
// Hijack is a hijack handler
type Hijack struct {
page *rod.Page
enable *proto.FetchEnable
disable *proto.FetchDisable
cancel func()
}
// SetPattern set pattern directly
func (h *Hijack) SetPattern(pattern *proto.FetchRequestPattern) {
h.enable = &proto.FetchEnable{
Patterns: []*proto.FetchRequestPattern{pattern},
}
}
// Start hijack.
func (h *Hijack) Start(handler HijackHandler) func() error {
if h.enable == nil {
panic("hijack pattern not set")
}
p, cancel := h.page.WithCancel()
h.cancel = cancel
err := h.enable.Call(p)
if err != nil {
return func() error { return err }
}
wait := p.EachEvent(func(e *proto.FetchRequestPaused) {
if handler != nil {
err = handler(e)
}
})
return func() error {
wait()
return err
}
}
// Stop
func (h *Hijack) Stop() error {
if h.cancel != nil {
h.cancel()
}
return h.disable.Call(h.page)
}
// FetchGetResponseBody get request body.
func FetchGetResponseBody(page *rod.Page, e *proto.FetchRequestPaused) ([]byte, error) {
m := proto.FetchGetResponseBody{
RequestID: e.RequestID,
}
r, err := m.Call(page)
if err != nil {
return nil, err
}
if !r.Base64Encoded {
return []byte(r.Body), nil
}
bs, err := base64.StdEncoding.DecodeString(r.Body)
if err != nil {
return nil, err
}
return bs, nil
}
// FetchContinueRequest continue request
func FetchContinueRequest(page *rod.Page, e *proto.FetchRequestPaused) error {
m := proto.FetchContinueRequest{
RequestID: e.RequestID,
}
return m.Call(page)
}

View File

@ -13,9 +13,10 @@ import (
// Page is a single page in an isolated browser instance
type Page struct {
page *rod.Page
rules []requestRule
rules []rule
instance *Instance
router *rod.HijackRouter
hijackRouter *rod.HijackRouter
hijackNative *Hijack
mutex *sync.RWMutex
History []HistoryData
InteractshURLs []string
@ -43,11 +44,27 @@ func (i *Instance) Run(baseURL *url.URL, actions []*Action, payloads map[string]
}
createdPage := &Page{page: page, instance: i, mutex: &sync.RWMutex{}, payloads: payloads}
router := page.HijackRequests()
if routerErr := router.Add("*", "", createdPage.routingRuleHandler); routerErr != nil {
return nil, nil, routerErr
// in case the page has request/response modification rules - enable global hijacking
if createdPage.hasModificationRules() || containsModificationActions(actions...) {
hijackRouter := page.HijackRequests()
if err := hijackRouter.Add("*", "", createdPage.routingRuleHandler); err != nil {
return nil, nil, err
}
createdPage.hijackRouter = hijackRouter
go hijackRouter.Run()
} else {
hijackRouter := NewHijack(page)
hijackRouter.SetPattern(&proto.FetchRequestPattern{
URLPattern: "*",
RequestStage: proto.FetchRequestStageResponse,
})
createdPage.hijackNative = hijackRouter
hijackRouterHandler := hijackRouter.Start(createdPage.routingRuleHandlerNative)
go func() {
_ = hijackRouterHandler()
}()
}
createdPage.router = router
if err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{Viewport: &proto.PageViewport{
Scale: 1,
@ -61,7 +78,6 @@ func (i *Instance) Run(baseURL *url.URL, actions []*Action, payloads map[string]
return nil, nil, err
}
go router.Run()
data, err := createdPage.ExecuteActions(baseURL, actions)
if err != nil {
return nil, nil, err
@ -71,7 +87,12 @@ func (i *Instance) Run(baseURL *url.URL, actions []*Action, payloads map[string]
// Close closes a browser page
func (p *Page) Close() {
_ = p.router.Stop()
if p.hijackRouter != nil {
_ = p.hijackRouter.Stop()
}
if p.hijackNative != nil {
_ = p.hijackNative.Stop()
}
p.page.Close()
}
@ -121,3 +142,40 @@ func (p *Page) addInteractshURL(URLs ...string) {
p.InteractshURLs = append(p.InteractshURLs, URLs...)
}
func (p *Page) hasModificationRules() bool {
for _, rule := range p.rules {
if containsAnyModificationActionType(rule.Action) {
return true
}
}
return false
}
func containsModificationActions(actions ...*Action) bool {
for _, action := range actions {
if containsAnyModificationActionType(action.ActionType.ActionType) {
return true
}
}
return false
}
func containsAnyModificationActionType(actionTypes ...ActionType) bool {
for _, actionType := range actionTypes {
switch actionType {
case ActionSetMethod:
return true
case ActionAddHeader:
return true
case ActionSetHeader:
return true
case ActionDeleteHeader:
return true
case ActionSetBody:
return true
}
}
return false
}

View File

@ -20,12 +20,14 @@ import (
)
var (
invalidArgumentsError = errors.New("invalid arguments provided")
errinvalidArguments = errors.New("invalid arguments provided")
reUrlWithPort = regexp.MustCompile(`{{BaseURL}}:(\d+)`)
)
const (
couldNotGetElementErrorMessage = "could not get element"
couldNotScrollErrorMessage = "could not scroll into view"
errCouldNotGetElement = "could not get element"
errCouldNotScroll = "could not scroll into view"
errElementDidNotAppear = "Element did not appear in the given amount of time"
)
// ExecuteActions executes a list of actions on a page.
@ -89,14 +91,12 @@ func (p *Page) ExecuteActions(baseURL *url.URL, actions []*Action) (map[string]s
return outData, nil
}
type requestRule struct {
type rule struct {
Action ActionType
Part string
Args map[string]string
}
const elementDidNotAppearMessage = "Element did not appear in the given amount of time"
// WaitVisible waits until an element appears.
func (p *Page) WaitVisible(act *Action, out map[string]string) error {
timeout, err := getTimeout(p, act)
@ -115,10 +115,10 @@ func (p *Page) WaitVisible(act *Action, out map[string]string) error {
if element != nil {
if err := element.WaitVisible(); err != nil {
return errors.Wrap(err, elementDidNotAppearMessage)
return errors.Wrap(err, errElementDidNotAppear)
}
} else {
return errors.New(elementDidNotAppearMessage)
return errors.New(errElementDidNotAppear)
}
return nil
@ -179,12 +179,7 @@ func (p *Page) ActionAddHeader(act *Action, out map[string]string /*TODO review
args := make(map[string]string)
args["key"] = p.getActionArgWithDefaultValues(act, "key")
args["value"] = p.getActionArgWithDefaultValues(act, "value")
rule := requestRule{
Action: ActionAddHeader,
Part: in,
Args: args,
}
p.rules = append(p.rules, rule)
p.rules = append(p.rules, rule{Action: ActionAddHeader, Part: in, Args: args})
return nil
}
@ -195,12 +190,7 @@ func (p *Page) ActionSetHeader(act *Action, out map[string]string /*TODO review
args := make(map[string]string)
args["key"] = p.getActionArgWithDefaultValues(act, "key")
args["value"] = p.getActionArgWithDefaultValues(act, "value")
rule := requestRule{
Action: ActionSetHeader,
Part: in,
Args: args,
}
p.rules = append(p.rules, rule)
p.rules = append(p.rules, rule{Action: ActionSetHeader, Part: in, Args: args})
return nil
}
@ -210,12 +200,7 @@ func (p *Page) ActionDeleteHeader(act *Action, out map[string]string /*TODO revi
args := make(map[string]string)
args["key"] = p.getActionArgWithDefaultValues(act, "key")
rule := requestRule{
Action: ActionDeleteHeader,
Part: in,
Args: args,
}
p.rules = append(p.rules, rule)
p.rules = append(p.rules, rule{Action: ActionDeleteHeader, Part: in, Args: args})
return nil
}
@ -225,12 +210,7 @@ func (p *Page) ActionSetBody(act *Action, out map[string]string /*TODO review un
args := make(map[string]string)
args["body"] = p.getActionArgWithDefaultValues(act, "body")
rule := requestRule{
Action: ActionSetBody,
Part: in,
Args: args,
}
p.rules = append(p.rules, rule)
p.rules = append(p.rules, rule{Action: ActionSetBody, Part: in, Args: args})
return nil
}
@ -240,12 +220,7 @@ func (p *Page) ActionSetMethod(act *Action, out map[string]string /*TODO review
args := make(map[string]string)
args["method"] = p.getActionArgWithDefaultValues(act, "method")
rule := requestRule{
Action: ActionSetMethod,
Part: in,
Args: args,
}
p.rules = append(p.rules, rule)
p.rules = append(p.rules, rule{Action: ActionSetMethod, Part: in, Args: args})
return nil
}
@ -253,7 +228,7 @@ func (p *Page) ActionSetMethod(act *Action, out map[string]string /*TODO review
func (p *Page) NavigateURL(action *Action, out map[string]string, parsed *url.URL /*TODO review unused parameter*/) error {
URL := p.getActionArgWithDefaultValues(action, "url")
if URL == "" {
return invalidArgumentsError
return errinvalidArguments
}
// Handle the dynamic value substitution here.
@ -277,7 +252,7 @@ func (p *Page) NavigateURL(action *Action, out map[string]string, parsed *url.UR
func (p *Page) RunScript(action *Action, out map[string]string) error {
code := p.getActionArgWithDefaultValues(action, "code")
if code == "" {
return invalidArgumentsError
return errinvalidArguments
}
if p.getActionArgWithDefaultValues(action, "hook") == "true" {
if _, err := p.page.EvalOnNewDocument(code); err != nil {
@ -298,10 +273,10 @@ func (p *Page) RunScript(action *Action, out map[string]string) error {
func (p *Page) ClickElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
if err = element.Click(proto.InputMouseButtonLeft, 1); err != nil {
return errors.Wrap(err, "could not click element")
@ -318,10 +293,10 @@ func (p *Page) KeyboardAction(act *Action, out map[string]string /*TODO review u
func (p *Page) RightClickElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
if err = element.Click(proto.InputMouseButtonRight, 1); err != nil {
return errors.Wrap(err, "could not right click element")
@ -359,14 +334,14 @@ func (p *Page) Screenshot(act *Action, out map[string]string) error {
func (p *Page) InputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
value := p.getActionArgWithDefaultValues(act, "value")
if value == "" {
return invalidArgumentsError
return errinvalidArguments
}
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
if err = element.Input(value); err != nil {
return errors.Wrap(err, "could not input element")
@ -378,14 +353,14 @@ func (p *Page) InputElement(act *Action, out map[string]string /*TODO review unu
func (p *Page) TimeInputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
value := p.getActionArgWithDefaultValues(act, "value")
if value == "" {
return invalidArgumentsError
return errinvalidArguments
}
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
t, err := time.Parse(time.RFC3339, value)
if err != nil {
@ -401,14 +376,14 @@ func (p *Page) TimeInputElement(act *Action, out map[string]string /*TODO review
func (p *Page) SelectInputElement(act *Action, out map[string]string /*TODO review unused parameter*/) error {
value := p.getActionArgWithDefaultValues(act, "value")
if value == "" {
return invalidArgumentsError
return errinvalidArguments
}
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
selectedBool := false
@ -440,7 +415,7 @@ func (p *Page) WaitLoad(act *Action, out map[string]string /*TODO review unused
func (p *Page) GetResource(act *Action, out map[string]string) error {
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
resource, err := element.Resource()
if err != nil {
@ -456,10 +431,10 @@ func (p *Page) GetResource(act *Action, out map[string]string) error {
func (p *Page) FilesInput(act *Action, out map[string]string /*TODO review unused parameter*/) error {
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
value := p.getActionArgWithDefaultValues(act, "value")
filesPaths := strings.Split(value, ",")
@ -473,10 +448,10 @@ func (p *Page) FilesInput(act *Action, out map[string]string /*TODO review unuse
func (p *Page) ExtractElement(act *Action, out map[string]string) error {
element, err := p.pageElementBy(act.Data)
if err != nil {
return errors.Wrap(err, couldNotGetElementErrorMessage)
return errors.Wrap(err, errCouldNotGetElement)
}
if err = element.ScrollIntoView(); err != nil {
return errors.Wrap(err, couldNotScrollErrorMessage)
return errors.Wrap(err, errCouldNotScroll)
}
switch p.getActionArgWithDefaultValues(act, "target") {
case "attribute":
@ -605,15 +580,11 @@ func selectorBy(selector string) rod.SelectorType {
}
}
var (
urlWithPortRegex = regexp.MustCompile(`{{BaseURL}}:(\d+)`)
)
// baseURLWithTemplatePrefs returns the url for BaseURL keeping
// the template port and path preference over the user provided one.
func baseURLWithTemplatePrefs(data string, parsed *url.URL) (string, *url.URL) {
// template port preference over input URL port if template has a port
matches := urlWithPortRegex.FindAllStringSubmatch(data, -1)
matches := reUrlWithPort.FindAllStringSubmatch(data, -1)
if len(matches) == 0 {
return data, parsed
}

View File

@ -537,3 +537,18 @@ func testHeadless(t *testing.T, actions []*Action, timeout time.Duration, handle
page.Close()
}
}
func TestContainsAnyModificationActionType(t *testing.T) {
if containsAnyModificationActionType() {
t.Error("Expected false, got true")
}
if containsAnyModificationActionType(ActionClick) {
t.Error("Expected false, got true")
}
if !containsAnyModificationActionType(ActionSetMethod, ActionAddHeader, ActionExtract) {
t.Error("Expected true, got false")
}
if !containsAnyModificationActionType(ActionSetMethod, ActionAddHeader, ActionSetHeader, ActionDeleteHeader, ActionSetBody) {
t.Error("Expected true, got false")
}
}

View File

@ -6,6 +6,7 @@ import (
"strings"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/proto"
)
// routingRuleHandler handles proxy rule for actions related to request/response modification
@ -79,3 +80,45 @@ func (p *Page) routingRuleHandler(ctx *rod.Hijack) {
}
p.addToHistory(historyData)
}
// routingRuleHandlerNative handles native proxy rule
func (p *Page) routingRuleHandlerNative(e *proto.FetchRequestPaused) error {
body, _ := FetchGetResponseBody(p.page, e)
headers := make(map[string][]string)
for _, h := range e.ResponseHeaders {
headers[h.Name] = []string{h.Value}
}
var statusCode int
if e.ResponseStatusCode != nil {
statusCode = *e.ResponseStatusCode
}
// attempts to rebuild request
var rawReq strings.Builder
rawReq.WriteString(fmt.Sprintf("%s %s %s\n", e.Request.Method, e.Request.URL, "HTTP/1.1"))
for _, header := range e.Request.Headers {
rawReq.WriteString(fmt.Sprintf("%s\n", header.String()))
}
if e.Request.HasPostData {
rawReq.WriteString(fmt.Sprintf("\n%s\n", e.Request.PostData))
}
// attempts to rebuild the response
var rawResp strings.Builder
rawResp.WriteString(fmt.Sprintf("HTTP/1.1 %d %s\n", statusCode, e.ResponseStatusText))
for _, header := range e.ResponseHeaders {
rawResp.WriteString(header.Name + ": " + header.Value + "\n")
}
rawResp.WriteString("\n")
rawResp.Write(body)
// dump request
historyData := HistoryData{
RawRequest: rawReq.String(),
RawResponse: rawResp.String(),
}
p.addToHistory(historyData)
return FetchContinueRequest(p.page, e)
}

View File

@ -23,7 +23,7 @@ import (
var _ protocols.Request = &Request{}
const couldGetHtmlElementErrorMessage = "could get html element"
const errCouldGetHtmlElement = "could get html element"
// Type returns the type of the protocol request
func (request *Request) Type() templateTypes.ProtocolType {
@ -81,7 +81,7 @@ func (request *Request) executeRequestWithPayloads(inputURL string, payloads map
if err != nil {
request.options.Output.Request(request.options.TemplatePath, inputURL, request.Type().String(), err)
request.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, couldGetHtmlElementErrorMessage)
return errors.Wrap(err, errCouldGetHtmlElement)
}
defer instance.Close()
@ -95,14 +95,14 @@ func (request *Request) executeRequestWithPayloads(inputURL string, payloads map
if err != nil {
request.options.Output.Request(request.options.TemplatePath, inputURL, request.Type().String(), err)
request.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, couldGetHtmlElementErrorMessage)
return errors.Wrap(err, errCouldGetHtmlElement)
}
timeout := time.Duration(request.options.Options.PageTimeout) * time.Second
out, page, err := instance.Run(parsedURL, request.Steps, payloads, timeout)
if err != nil {
request.options.Output.Request(request.options.TemplatePath, inputURL, request.Type().String(), err)
request.options.Progress.IncrementFailedRequestsBy(1)
return errors.Wrap(err, couldGetHtmlElementErrorMessage)
return errors.Wrap(err, errCouldGetHtmlElement)
}
defer page.Close()