mirror of https://github.com/daffainfo/nuclei.git
Add template sign/verify functionality (#3029)
* add template sign/verify functionality * fixing syntaxdev
parent
aeb5dbd293
commit
62af038617
|
@ -0,0 +1,148 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/projectdiscovery/goflags"
|
||||
"github.com/projectdiscovery/nuclei/v2/pkg/templates/signer"
|
||||
stringsutil "github.com/projectdiscovery/utils/strings"
|
||||
)
|
||||
|
||||
type options struct {
|
||||
Templates goflags.StringSlice
|
||||
Algorithm string
|
||||
PrivateKeyName string
|
||||
PrivateKeyPassPhrase string
|
||||
PublicKeyName string
|
||||
}
|
||||
|
||||
func ParseOptions() (*options, error) {
|
||||
opts := &options{}
|
||||
flagSet := goflags.NewFlagSet()
|
||||
flagSet.SetDescription(`sign-templates is a utility to perform template signature`)
|
||||
|
||||
flagSet.CreateGroup("sign", "sign",
|
||||
flagSet.StringSliceVarP(&opts.Templates, "templates", "t", nil, "templates files/folders to sign", goflags.CommaSeparatedStringSliceOptions),
|
||||
flagSet.StringVarP(&opts.Algorithm, "algorithm", "a", "ecdsa", "signature algorithm (ecdsa, rsa)"),
|
||||
flagSet.StringVarP(&opts.PrivateKeyName, "private-key", "prk", "", "private key env var name or file location"),
|
||||
flagSet.StringVarP(&opts.PrivateKeyPassPhrase, "private-key-pass", "prkp", "", "private key passphrase env var name or file location"),
|
||||
flagSet.StringVarP(&opts.PublicKeyName, "public-key", "puk", "", "public key env var name or file location"),
|
||||
)
|
||||
|
||||
if err := flagSet.Parse(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
opts, err := ParseOptions()
|
||||
if err != nil {
|
||||
log.Fatalf("couldn't parse options: %s\n", err)
|
||||
}
|
||||
|
||||
var algo signer.AlgorithmType
|
||||
switch opts.Algorithm {
|
||||
case "rsa":
|
||||
algo = signer.RSA
|
||||
case "ecdsa":
|
||||
algo = signer.ECDSA
|
||||
default:
|
||||
log.Fatal("unknown algorithm type")
|
||||
}
|
||||
|
||||
signerOptions := &signer.Options{
|
||||
PrivateKeyName: opts.PrivateKeyName,
|
||||
PassphraseName: opts.PrivateKeyPassPhrase,
|
||||
PublicKeyName: opts.PublicKeyName,
|
||||
Algorithm: algo,
|
||||
}
|
||||
sign, err := signer.New(signerOptions)
|
||||
if err != nil {
|
||||
log.Fatalf("couldn't create crypto engine: %s\n", err)
|
||||
}
|
||||
|
||||
for _, templateItem := range opts.Templates {
|
||||
if err := processItem(sign, templateItem); err != nil {
|
||||
log.Fatalf("Could not walk directory: %s\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func processItem(sign *signer.Signer, item string) error {
|
||||
return filepath.WalkDir(item, func(iterItem string, d fs.DirEntry, err error) error {
|
||||
if err != nil || d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := processFile(sign, iterItem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func processFile(sign *signer.Signer, filePath string) error {
|
||||
ext := filepath.Ext(filePath)
|
||||
if !stringsutil.EqualFoldAny(ext, ".yaml") {
|
||||
return nil
|
||||
}
|
||||
err := signTemplate(sign, filePath)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not sign template: %s", filePath)
|
||||
}
|
||||
|
||||
ok, err := verifyTemplateSignature(sign, filePath)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "could not verify template: %s", filePath)
|
||||
}
|
||||
if !ok {
|
||||
return errors.Wrapf(err, "template signature doesn't match: %s", filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendToFile(path string, data []byte, digest string) error {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := file.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := file.WriteString("\n" + digest); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func signTemplate(sign *signer.Signer, templatePath string) error {
|
||||
templateData, err := os.ReadFile(templatePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
signatureData, err := signer.Sign(sign, templateData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dataWithoutSignature := signer.RemoveSignatureFromData(templateData)
|
||||
return appendToFile(templatePath, dataWithoutSignature, signatureData)
|
||||
}
|
||||
|
||||
func verifyTemplateSignature(sign *signer.Signer, templatePath string) (bool, error) {
|
||||
templateData, err := os.ReadFile(templatePath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return signer.Verify(sign, templateData)
|
||||
}
|
|
@ -209,7 +209,7 @@ require (
|
|||
go.etcd.io/bbolt v1.3.7 // indirect
|
||||
go.uber.org/zap v1.23.0 // indirect
|
||||
goftp.io/server/v2 v2.0.0 // indirect
|
||||
golang.org/x/crypto v0.5.0 // indirect
|
||||
golang.org/x/crypto v0.5.0
|
||||
golang.org/x/exp v0.0.0-20230206171751-46f607a40771
|
||||
golang.org/x/mod v0.8.0 // indirect
|
||||
golang.org/x/sys v0.5.0 // indirect
|
||||
|
|
|
@ -49,20 +49,18 @@ const (
|
|||
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
)
|
||||
|
||||
var invalidDslFunctionError = errors.New("invalid DSL function signature")
|
||||
var invalidDslFunctionMessageTemplate = "%w. correct method signature %q"
|
||||
|
||||
var dslFunctions map[string]dslFunction
|
||||
|
||||
var (
|
||||
ErrinvalidDslFunction = errors.New("invalid DSL function signature")
|
||||
dslFunctions map[string]dslFunction
|
||||
|
||||
// FunctionNames is a list of function names for expression evaluation usages
|
||||
FunctionNames []string
|
||||
// HelperFunctions is a pre-compiled list of govaluate DSL functions
|
||||
HelperFunctions map[string]govaluate.ExpressionFunction
|
||||
)
|
||||
|
||||
var functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+([.\w\d{}&*]+))?\)([\s.\w\d{}&*]+)?`)
|
||||
var dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
|
||||
functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+([.\w\d{}&*]+))?\)([\s.\w\d{}&*]+)?`)
|
||||
dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
|
||||
)
|
||||
|
||||
type dslFunction struct {
|
||||
signatures []string
|
||||
|
@ -98,7 +96,7 @@ func init() {
|
|||
func(args ...interface{}) (interface{}, error) {
|
||||
argCount := len(args)
|
||||
if argCount == 0 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
} else if argCount == 1 {
|
||||
runes := []rune(types.ToString(args[0]))
|
||||
sort.Slice(runes, func(i int, j int) bool {
|
||||
|
@ -122,7 +120,7 @@ func init() {
|
|||
func(args ...interface{}) (interface{}, error) {
|
||||
argCount := len(args)
|
||||
if argCount == 0 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
} else if argCount == 1 {
|
||||
builder := &strings.Builder{}
|
||||
visited := make(map[rune]struct{})
|
||||
|
@ -149,7 +147,7 @@ func init() {
|
|||
"repeat": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
count, err := strconv.Atoi(types.ToString(args[1]))
|
||||
if err != nil {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
return strings.Repeat(types.ToString(args[0]), count), nil
|
||||
}),
|
||||
|
@ -243,7 +241,7 @@ func init() {
|
|||
|
||||
argumentsSize := len(arguments)
|
||||
if argumentsSize < 1 && argumentsSize > 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
currentTime, err := getCurrentTimeFromUserInput(arguments)
|
||||
|
@ -353,7 +351,7 @@ func init() {
|
|||
"(str string, prefix ...string) bool",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
for _, prefix := range args[1:] {
|
||||
if strings.HasPrefix(types.ToString(args[0]), types.ToString(prefix)) {
|
||||
|
@ -366,7 +364,7 @@ func init() {
|
|||
"line_starts_with": makeDslWithOptionalArgsFunction(
|
||||
"(str string, prefix ...string) bool", func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
for _, line := range strings.Split(types.ToString(args[0]), "\n") {
|
||||
for _, prefix := range args[1:] {
|
||||
|
@ -382,7 +380,7 @@ func init() {
|
|||
"(str string, suffix ...string) bool",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
for _, suffix := range args[1:] {
|
||||
if strings.HasSuffix(types.ToString(args[0]), types.ToString(suffix)) {
|
||||
|
@ -395,7 +393,7 @@ func init() {
|
|||
"line_ends_with": makeDslWithOptionalArgsFunction(
|
||||
"(str string, suffix ...string) bool", func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
for _, line := range strings.Split(types.ToString(args[0]), "\n") {
|
||||
for _, suffix := range args[1:] {
|
||||
|
@ -436,11 +434,11 @@ func init() {
|
|||
separator := types.ToString(arguments[1])
|
||||
count, err := strconv.Atoi(types.ToString(arguments[2]))
|
||||
if err != nil {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
return strings.SplitN(input, separator, count), nil
|
||||
} else {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
},
|
||||
),
|
||||
|
@ -450,7 +448,7 @@ func init() {
|
|||
func(arguments ...interface{}) (interface{}, error) {
|
||||
argumentsSize := len(arguments)
|
||||
if argumentsSize < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
} else if argumentsSize == 2 {
|
||||
separator := types.ToString(arguments[0])
|
||||
elements, ok := arguments[1].([]string)
|
||||
|
@ -495,7 +493,7 @@ func init() {
|
|||
|
||||
argSize := len(args)
|
||||
if argSize != 0 && argSize != 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
if argSize >= 1 {
|
||||
|
@ -516,7 +514,7 @@ func init() {
|
|||
|
||||
argSize := len(args)
|
||||
if argSize < 1 || argSize > 3 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
length = int(args[0].(float64))
|
||||
|
@ -538,7 +536,7 @@ func init() {
|
|||
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
length = int(args[0].(float64))
|
||||
|
@ -558,7 +556,7 @@ func init() {
|
|||
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
length = int(args[0].(float64))
|
||||
|
@ -575,7 +573,7 @@ func init() {
|
|||
func(args ...interface{}) (interface{}, error) {
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
length := int(args[0].(float64))
|
||||
|
@ -594,7 +592,7 @@ func init() {
|
|||
func(args ...interface{}) (interface{}, error) {
|
||||
argSize := len(args)
|
||||
if argSize > 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
min := 0
|
||||
|
@ -613,7 +611,7 @@ func init() {
|
|||
"(cidr ...string) string",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) == 0 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
var cidrs []string
|
||||
for _, arg := range args {
|
||||
|
@ -635,7 +633,7 @@ func init() {
|
|||
|
||||
argSize := len(args)
|
||||
if argSize != 0 && argSize != 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
} else if argSize == 1 {
|
||||
seconds = int(args[0].(float64))
|
||||
}
|
||||
|
@ -670,7 +668,7 @@ func init() {
|
|||
}
|
||||
return parsedTime.Unix(), err
|
||||
} else {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
},
|
||||
),
|
||||
|
@ -678,7 +676,7 @@ func init() {
|
|||
"(seconds uint)",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
seconds := args[0].(float64)
|
||||
time.Sleep(time.Duration(seconds) * time.Second)
|
||||
|
@ -689,7 +687,7 @@ func init() {
|
|||
"(firstVersion, constraints ...string) bool",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
|
||||
firstParsed, parseErr := version.NewVersion(types.ToString(args[0]))
|
||||
|
@ -713,7 +711,7 @@ func init() {
|
|||
"(args ...interface{})",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
gologger.Info().Msgf("print_debug value: %s", fmt.Sprint(args))
|
||||
return true, nil
|
||||
|
@ -753,7 +751,7 @@ func init() {
|
|||
"(str string, start int, optionalEnd int)",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
argStr := types.ToString(args[0])
|
||||
start, err := strconv.Atoi(types.ToString(args[1]))
|
||||
|
@ -817,7 +815,7 @@ func init() {
|
|||
argSize := len(args)
|
||||
|
||||
if argSize < 1 || argSize > 4 {
|
||||
return nil, invalidDslFunctionError
|
||||
return nil, ErrinvalidDslFunction
|
||||
}
|
||||
jsonString := args[0].(string)
|
||||
|
||||
|
@ -968,7 +966,7 @@ func makeDslFunction(numberOfParameters int, dslFunctionLogic govaluate.Expressi
|
|||
[]string{signature},
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != numberOfParameters {
|
||||
return nil, fmt.Errorf(invalidDslFunctionMessageTemplate, invalidDslFunctionError, signature)
|
||||
return nil, fmt.Errorf("%w. correct method signature %q", ErrinvalidDslFunction, signature)
|
||||
}
|
||||
return dslFunctionLogic(args...)
|
||||
},
|
||||
|
|
|
@ -53,7 +53,7 @@ func TestDSLGzipSerialize(t *testing.T) {
|
|||
|
||||
func TestDslFunctionSignatures(t *testing.T) {
|
||||
createSignatureError := func(signature string) string {
|
||||
return fmt.Errorf(invalidDslFunctionMessageTemplate, invalidDslFunctionError, signature).Error()
|
||||
return fmt.Errorf("%w. correct method signature %q", ErrinvalidDslFunction, signature).Error()
|
||||
}
|
||||
|
||||
toUpperSignatureError := createSignatureError("to_upper(arg1 interface{}) interface{}")
|
||||
|
|
|
@ -9,14 +9,15 @@ import (
|
|||
"github.com/projectdiscovery/nuclei/v2/pkg/catalog/loader/filter"
|
||||
"github.com/projectdiscovery/nuclei/v2/pkg/templates"
|
||||
"github.com/projectdiscovery/nuclei/v2/pkg/templates/cache"
|
||||
"github.com/projectdiscovery/nuclei/v2/pkg/templates/signer"
|
||||
"github.com/projectdiscovery/nuclei/v2/pkg/utils"
|
||||
"github.com/projectdiscovery/nuclei/v2/pkg/utils/stats"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
mandatoryFieldMissingTemplate = "mandatory '%s' field is missing"
|
||||
invalidFieldFormatTemplate = "invalid field format for '%s' (allowed format is %s)"
|
||||
errMandatoryFieldMissingFmt = "mandatory '%s' field is missing"
|
||||
errInvalidFieldFmt = "invalid field format for '%s' (allowed format is %s)"
|
||||
)
|
||||
|
||||
// LoadTemplate returns true if the template is valid and matches the filtering criteria.
|
||||
|
@ -71,17 +72,17 @@ func validateTemplateFields(template *templates.Template) error {
|
|||
var errors []string
|
||||
|
||||
if utils.IsBlank(info.Name) {
|
||||
errors = append(errors, fmt.Sprintf(mandatoryFieldMissingTemplate, "name"))
|
||||
errors = append(errors, fmt.Sprintf(errMandatoryFieldMissingFmt, "name"))
|
||||
}
|
||||
|
||||
if info.Authors.IsEmpty() {
|
||||
errors = append(errors, fmt.Sprintf(mandatoryFieldMissingTemplate, "author"))
|
||||
errors = append(errors, fmt.Sprintf(errMandatoryFieldMissingFmt, "author"))
|
||||
}
|
||||
|
||||
if template.ID == "" {
|
||||
errors = append(errors, fmt.Sprintf(mandatoryFieldMissingTemplate, "id"))
|
||||
errors = append(errors, fmt.Sprintf(errMandatoryFieldMissingFmt, "id"))
|
||||
} else if !templateIDRegexp.MatchString(template.ID) {
|
||||
errors = append(errors, fmt.Sprintf(invalidFieldFormatTemplate, "id", templateIDRegexp.String()))
|
||||
errors = append(errors, fmt.Sprintf(errInvalidFieldFmt, "id", templateIDRegexp.String()))
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
|
@ -105,7 +106,6 @@ const (
|
|||
)
|
||||
|
||||
func init() {
|
||||
|
||||
parsedTemplatesCache = cache.New()
|
||||
|
||||
stats.NewEntry(SyntaxWarningStats, "Found %d templates with syntax warning (use -validate flag for further examination)")
|
||||
|
@ -124,6 +124,12 @@ func ParseTemplate(templatePath string, catalog catalog.Catalog) (*templates.Tem
|
|||
}
|
||||
|
||||
template := &templates.Template{}
|
||||
|
||||
// check if the template is verified
|
||||
if signer.DefaultVerifier != nil {
|
||||
template.Verified, _ = signer.Verify(signer.DefaultVerifier, data)
|
||||
}
|
||||
|
||||
if NoStrictSyntax {
|
||||
err = yaml.Unmarshal(data, template)
|
||||
} else {
|
||||
|
@ -133,6 +139,7 @@ func ParseTemplate(templatePath string, catalog catalog.Catalog) (*templates.Tem
|
|||
stats.Increment(SyntaxErrorStats)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsedTemplatesCache.Store(templatePath, template, nil)
|
||||
return template, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
package signer
|
||||
|
||||
var DefaultVerifier *Signer
|
||||
|
||||
func init() {
|
||||
DefaultVerifier, _ = NewVerifier(&Options{PublicKeyData: ecdsaPublicKey, Algorithm: ECDSA})
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
package signer
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed ecdsa_public_key.pem
|
||||
var ecdsaPublicKey []byte
|
|
@ -0,0 +1,4 @@
|
|||
-----BEGIN PUBLIC KEY-----
|
||||
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
-----END PUBLIC KEY-----
|
|
@ -0,0 +1,34 @@
|
|||
package signer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/big"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
type AlgorithmType uint8
|
||||
|
||||
const (
|
||||
RSA AlgorithmType = iota
|
||||
ECDSA
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
PrivateKeyName string
|
||||
PrivateKeyData []byte
|
||||
PassphraseName string
|
||||
PassphraseData []byte
|
||||
PublicKeyName string
|
||||
PublicKeyData []byte
|
||||
Algorithm AlgorithmType
|
||||
}
|
||||
|
||||
type EcdsaSignature struct {
|
||||
R *big.Int
|
||||
S *big.Int
|
||||
}
|
||||
|
||||
var (
|
||||
ReDigest = regexp.MustCompile(`(?m)^#\sdigest:\s.+$`)
|
||||
ErrUnknownAlgorithm = errors.New("unknown algorithm")
|
||||
)
|
|
@ -0,0 +1,234 @@
|
|||
package signer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/gob"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
fileutil "github.com/projectdiscovery/utils/file"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Signer struct {
|
||||
options *Options
|
||||
sshSigner ssh.Signer
|
||||
sshVerifier ssh.PublicKey
|
||||
ecdsaSigner *ecdsa.PrivateKey
|
||||
ecdsaVerifier *ecdsa.PublicKey
|
||||
}
|
||||
|
||||
func New(options *Options) (*Signer, error) {
|
||||
var (
|
||||
privateKeyData, passphraseData, publicKeyData []byte
|
||||
err error
|
||||
)
|
||||
if options.PrivateKeyName != "" {
|
||||
privateKeyData, err = readKeyFromFileOrEnv(options.PrivateKeyName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
privateKeyData = options.PrivateKeyData
|
||||
}
|
||||
|
||||
if options.PassphraseName != "" {
|
||||
passphraseData = readKeyFromFileOrEnvWithDefault(options.PassphraseName, []byte{})
|
||||
} else {
|
||||
passphraseData = options.PassphraseData
|
||||
}
|
||||
|
||||
if options.PublicKeyName != "" {
|
||||
publicKeyData, err = readKeyFromFileOrEnv(options.PublicKeyName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
publicKeyData = options.PublicKeyData
|
||||
}
|
||||
|
||||
signer := &Signer{options: options}
|
||||
|
||||
switch signer.options.Algorithm {
|
||||
case RSA:
|
||||
signer.sshSigner, signer.sshVerifier, err = parseRsa(privateKeyData, publicKeyData, passphraseData)
|
||||
case ECDSA:
|
||||
signer.ecdsaSigner, signer.ecdsaVerifier, err = parseECDSA(privateKeyData, publicKeyData)
|
||||
default:
|
||||
return nil, ErrUnknownAlgorithm
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func NewVerifier(options *Options) (*Signer, error) {
|
||||
var (
|
||||
publicKeyData []byte
|
||||
err error
|
||||
)
|
||||
if options.PublicKeyName != "" {
|
||||
publicKeyData, err = readKeyFromFileOrEnv(options.PrivateKeyName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
publicKeyData = options.PublicKeyData
|
||||
}
|
||||
|
||||
signer := &Signer{options: options}
|
||||
|
||||
switch signer.options.Algorithm {
|
||||
case RSA:
|
||||
signer.sshVerifier, err = parseRsaPublicKey(publicKeyData)
|
||||
case ECDSA:
|
||||
signer.ecdsaVerifier, err = parseECDSAPublicKey(publicKeyData)
|
||||
default:
|
||||
return nil, ErrUnknownAlgorithm
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *Signer) Sign(data []byte) ([]byte, error) {
|
||||
dataHash := sha256.Sum256(data)
|
||||
switch s.options.Algorithm {
|
||||
case RSA:
|
||||
sshSignature, err := s.sshSigner.Sign(rand.Reader, dataHash[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var signatureData bytes.Buffer
|
||||
if err := gob.NewEncoder(&signatureData).Encode(sshSignature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return signatureData.Bytes(), nil
|
||||
case ECDSA:
|
||||
r, s, err := ecdsa.Sign(rand.Reader, s.ecdsaSigner, dataHash[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ecdsaSignature := &EcdsaSignature{R: r, S: s}
|
||||
var signatureData bytes.Buffer
|
||||
if err := gob.NewEncoder(&signatureData).Encode(ecdsaSignature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return signatureData.Bytes(), nil
|
||||
default:
|
||||
return nil, ErrUnknownAlgorithm
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Signer) Verify(data, signatureData []byte) (bool, error) {
|
||||
dataHash := sha256.Sum256(data)
|
||||
switch s.options.Algorithm {
|
||||
case RSA:
|
||||
signature := &ssh.Signature{}
|
||||
if err := gob.NewDecoder(bytes.NewReader(signatureData)).Decode(&signature); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := s.sshVerifier.Verify(dataHash[:], signature); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
case ECDSA:
|
||||
signature := &EcdsaSignature{}
|
||||
if err := gob.NewDecoder(bytes.NewReader(signatureData)).Decode(&signature); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return ecdsa.Verify(s.ecdsaVerifier, dataHash[:], signature.R, signature.S), nil
|
||||
default:
|
||||
return false, ErrUnknownAlgorithm
|
||||
}
|
||||
}
|
||||
|
||||
func parseRsa(privateKeyData, passphraseData, publicKeyData []byte) (ssh.Signer, ssh.PublicKey, error) {
|
||||
privateKey, err := parseRsaPrivateKey(privateKeyData, passphraseData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
publicKey, err := parseRsaPublicKey(publicKeyData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return privateKey, publicKey, nil
|
||||
}
|
||||
|
||||
func parseRsaPrivateKey(privateKeyData, passphraseData []byte) (ssh.Signer, error) {
|
||||
if len(passphraseData) > 0 {
|
||||
return ssh.ParsePrivateKeyWithPassphrase(privateKeyData, passphraseData)
|
||||
}
|
||||
return ssh.ParsePrivateKey(privateKeyData)
|
||||
}
|
||||
|
||||
func parseRsaPublicKey(publicKeyData []byte) (ssh.PublicKey, error) {
|
||||
publicKey, _, _, _, err := ssh.ParseAuthorizedKey(publicKeyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
func parseECDSA(privateKeyData, publicKeyData []byte) (*ecdsa.PrivateKey, *ecdsa.PublicKey, error) {
|
||||
privateKey, err := parseECDSAPrivateKey(privateKeyData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
publicKey, err := parseECDSAPublicKey(publicKeyData)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return privateKey, publicKey, nil
|
||||
}
|
||||
|
||||
func parseECDSAPrivateKey(privateKeyData []byte) (*ecdsa.PrivateKey, error) {
|
||||
blockPriv, _ := pem.Decode(privateKeyData)
|
||||
return x509.ParseECPrivateKey(blockPriv.Bytes)
|
||||
}
|
||||
|
||||
func parseECDSAPublicKey(publicKeyData []byte) (*ecdsa.PublicKey, error) {
|
||||
blockPub, _ := pem.Decode(publicKeyData)
|
||||
genericPublicKey, err := x509.ParsePKIXPublicKey(blockPub.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if publicKey, ok := genericPublicKey.(*ecdsa.PublicKey); ok {
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("couldn't parse ecdsa public key")
|
||||
}
|
||||
|
||||
func readKeyFromFileOrEnvWithDefault(keypath string, defaultValue []byte) []byte {
|
||||
keyValue, err := readKeyFromFileOrEnv(keypath)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return keyValue
|
||||
}
|
||||
|
||||
func readKeyFromFileOrEnv(keypath string) ([]byte, error) {
|
||||
if fileutil.FileExists(keypath) {
|
||||
return os.ReadFile(keypath)
|
||||
}
|
||||
if keydata := os.Getenv(keypath); keydata != "" {
|
||||
return []byte(keydata), nil
|
||||
}
|
||||
return nil, fmt.Errorf("Private key not found in file or environment variable: %s", keypath)
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
package signer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
SignaturePattern = "# digest: "
|
||||
SignatureFmt = SignaturePattern + "%x"
|
||||
)
|
||||
|
||||
func RemoveSignatureFromData(data []byte) []byte {
|
||||
return bytes.Trim(ReDigest.ReplaceAll(data, []byte("")), "\n")
|
||||
}
|
||||
|
||||
func Sign(sign *Signer, data []byte) (string, error) {
|
||||
if sign == nil {
|
||||
return "", errors.New("invalid nil signer")
|
||||
}
|
||||
cleanedData := RemoveSignatureFromData(data)
|
||||
signatureData, err := sign.Sign(cleanedData)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf(SignatureFmt, signatureData), nil
|
||||
}
|
||||
|
||||
func Verify(sign *Signer, data []byte) (bool, error) {
|
||||
if sign == nil {
|
||||
return false, errors.New("invalid nil verifier")
|
||||
}
|
||||
digestData := ReDigest.Find(data)
|
||||
if len(digestData) == 0 {
|
||||
return false, errors.New("digest not found")
|
||||
}
|
||||
|
||||
digestData = bytes.TrimSpace(bytes.TrimPrefix(digestData, []byte(SignaturePattern)))
|
||||
digest, err := hex.DecodeString(string(digestData))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
cleanedData := RemoveSignatureFromData(data)
|
||||
|
||||
return sign.Verify(cleanedData, digest)
|
||||
}
|
|
@ -109,6 +109,9 @@ type Template struct {
|
|||
Executer protocols.Executer `yaml:"-" json:"-"`
|
||||
|
||||
Path string `yaml:"-" json:"-"`
|
||||
|
||||
// Verified defines if the template signature is digitally verified
|
||||
Verified bool `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// TemplateProtocols is a list of accepted template protocols
|
||||
|
|
Loading…
Reference in New Issue