diff --git a/v2/pkg/protocols/others/websocket/websocket.go b/v2/pkg/protocols/others/websocket/websocket.go index db62c663..4cf3ff74 100644 --- a/v2/pkg/protocols/others/websocket/websocket.go +++ b/v2/pkg/protocols/others/websocket/websocket.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "io" + "net/http" "net/url" "strings" "time" @@ -32,6 +33,9 @@ type Request struct { // description: | // Inputs contains inputs for the websocket protocol Inputs []*Input `yaml:"inputs,omitempty" jsonschema:"title=inputs for the websocket request,description=Inputs contains any input/output for the current request"` + // description: | + // Origin is the websocket request origin. + Origin string `yaml:"origin,omitempty" jsonschema:"title=origin is the request origin,description=Origin is the websocket request origin"` // description: | // Attack is the type of payload combinations to perform. @@ -163,7 +167,15 @@ func (r *Request) ExecuteWithResults(input string, dynamicValues, previous outpu // ExecuteWithResults executes the protocol requests and returns results instead of writing them. func (r *Request) executeRequestWithPayloads(input, hostname string, dynamicValues, previous output.InternalEvent, callback protocols.OutputEventCallback) error { + var header ws.HandshakeHeader + + if r.Origin != "" { + header = ws.HandshakeHeaderHTTP(http.Header{ + "Origin": []string{r.Origin}, + }) + } websocketDialer := ws.Dialer{ + Header: header, Timeout: time.Duration(r.options.Options.Timeout) * time.Second, NetDial: r.dialer.Dial, TLSConfig: &tls.Config{InsecureSkipVerify: true, ServerName: hostname}, @@ -179,7 +191,7 @@ func (r *Request) executeRequestWithPayloads(input, hostname string, dynamicValu responseBuilder := &strings.Builder{} if readBuffer != nil { - io.Copy(responseBuilder, readBuffer) // Copy initial response + _, _ = io.Copy(responseBuilder, readBuffer) // Copy initial response } reqBuilder := &strings.Builder{} @@ -187,7 +199,6 @@ func (r *Request) executeRequestWithPayloads(input, hostname string, dynamicValu inputEvents := make(map[string]interface{}) for _, req := range r.Inputs { reqBuilder.Grow(len(req.Data)) - reqBuilder.WriteString(req.Data) finalData, dataErr := expressions.EvaluateByte([]byte(req.Data), dynamicValues) if dataErr != nil { @@ -195,6 +206,8 @@ func (r *Request) executeRequestWithPayloads(input, hostname string, dynamicValu r.options.Progress.IncrementFailedRequestsBy(1) return errors.Wrap(dataErr, "could not evaluate template expressions") } + reqBuilder.WriteString(string(finalData)) + err = wsutil.WriteClientMessage(conn, ws.OpText, finalData) if err != nil { r.options.Output.Request(r.options.TemplateID, input, "websocket", err) @@ -252,6 +265,7 @@ func (r *Request) executeRequestWithPayloads(input, hostname string, dynamicValu for k, v := range inputEvents { data[k] = v } + data["success"] = "true" data["request"] = reqBuilder.String() data["response"] = responseBuilder.String() data["host"] = input @@ -294,7 +308,7 @@ func (r *Request) makeResultEventItem(wrapped *output.InternalWrappedEvent) *out Timestamp: time.Now(), IP: types.ToString(wrapped.InternalEvent["ip"]), Request: types.ToString(wrapped.InternalEvent["request"]), - Response: types.ToString(wrapped.InternalEvent["responses"]), + Response: types.ToString(wrapped.InternalEvent["response"]), } return data }