moving aws signing logic to helper library

dev
mzack 2021-11-13 02:13:48 +01:00
parent 34889d50f8
commit e517797cfa
2 changed files with 102 additions and 22 deletions

View File

@ -0,0 +1,85 @@
package http
import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
)
type AwsSigner struct {
creds *credentials.Credentials
signer *v4.Signer
}
type SignArguments struct {
Service string
Region string
Time time.Time
}
func NewAwsSigner(awsId, awsSecretToken string) (*AwsSigner, error) {
if awsId == "" {
return nil, errors.New("empty id")
}
if awsSecretToken == "" {
return nil, errors.New("empty token")
}
creds := credentials.NewStaticCredentials(awsId, awsSecretToken, "")
if creds == nil {
return nil, errors.New("couldn't create the credentials structure")
}
signer := v4.NewSigner(creds)
return &AwsSigner{creds: creds, signer: signer}, nil
}
func NewAwsSignerFromEnv() (*AwsSigner, error) {
creds := credentials.NewEnvCredentials()
if creds == nil {
return nil, errors.New("couldn't create the credentials structure")
}
return &AwsSigner{creds: creds}, nil
}
func (awsSigner *AwsSigner) SignHTTP(request *http.Request, args SignArguments) error {
awsSigner.prepareRequest(request)
var body *bytes.Reader
if request.Body != nil {
bodyBytes, err := ioutil.ReadAll(request.Body)
if err != nil {
return err
}
request.Body.Close()
body = bytes.NewReader(bodyBytes)
}
if _, err := awsSigner.signer.Sign(request, body, args.Service, args.Region, args.Time); err != nil {
return err
}
return nil
}
func (awsSigner *AwsSigner) CalculateHTTPHeaders(request *http.Request, args SignArguments) (map[string]string, error) {
reqClone := request.Clone(context.Background())
awsSigner.prepareRequest(reqClone)
err := awsSigner.SignHTTP(reqClone, args)
if err != nil {
return nil, err
}
headers := make(map[string]string)
headers["X-Amz-Date"] = reqClone.Header.Get("X-Amz-Date")
headers["Authorization"] = reqClone.Header.Get("Authorization")
return headers, nil
}
func (awsSigner *AwsSigner) prepareRequest(request *http.Request) {
request.Header.Del("Host")
}

View File

@ -13,8 +13,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/remeh/sizedwaitgroup" "github.com/remeh/sizedwaitgroup"
"go.uber.org/multierr" "go.uber.org/multierr"
@ -31,6 +29,7 @@ import (
"github.com/projectdiscovery/nuclei/v2/pkg/protocols/common/tostring" "github.com/projectdiscovery/nuclei/v2/pkg/protocols/common/tostring"
"github.com/projectdiscovery/nuclei/v2/pkg/protocols/http/httpclientpool" "github.com/projectdiscovery/nuclei/v2/pkg/protocols/http/httpclientpool"
templateTypes "github.com/projectdiscovery/nuclei/v2/pkg/templates/types" templateTypes "github.com/projectdiscovery/nuclei/v2/pkg/templates/types"
"github.com/projectdiscovery/nuclei/v2/pkg/types"
"github.com/projectdiscovery/rawhttp" "github.com/projectdiscovery/rawhttp"
"github.com/projectdiscovery/stringsutil" "github.com/projectdiscovery/stringsutil"
) )
@ -352,33 +351,29 @@ func (request *Request) executeRequest(reqURL string, generatedRequest *generate
if resp == nil { if resp == nil {
// aws sign the request if necessary // aws sign the request if necessary
if request.AwsSign { if request.AwsSign {
generatedRequest.request.Header.Del("Host") var awsSigner *AwsSigner
payloads := request.options.Options.Vars.AsMap() payloads := request.options.Options.Vars.AsMap()
var creds *credentials.Credentials
if request.options.Options.EnvironmentVariables { if request.options.Options.EnvironmentVariables {
// get from env var err error
creds = credentials.NewEnvCredentials() awsSigner, err = NewAwsSignerFromEnv()
} else { // get from variables { if err != nil {
awsAccessKeyId := payloads["aws-id"] return err
awsSecretAccessKey := payloads["aws-secret"] }
creds = credentials.NewStaticCredentials(awsAccessKeyId.(string), awsSecretAccessKey.(string), "") } else { // get from variables {
} awsAccessKeyId := types.ToString(payloads["aws-id"])
awsSecretAccessKey := types.ToString(payloads["aws-secret"])
signer := v4.NewSigner(creds) awsSigner, err = NewAwsSigner(awsAccessKeyId, awsSecretAccessKey)
var body *bytes.Reader
if generatedRequest.request.Request.Body != nil {
bodyBytes, err := ioutil.ReadAll(generatedRequest.request.Request.Body)
if err != nil { if err != nil {
return err return err
} }
generatedRequest.request.Request.Body.Close()
body = bytes.NewReader(bodyBytes)
} }
service := payloads["service"].(string) args := SignArguments{
region := payloads["region"].(string) Service: types.ToString(payloads["service"]),
now := time.Now() Region: types.ToString(payloads["region"]),
_, err = signer.Sign(generatedRequest.request.Request, body, service, region, now) Time: time.Now(),
}
err = awsSigner.SignHTTP(generatedRequest.request.Request, args)
if err != nil { if err != nil {
return err return err
} }