mirror of https://github.com/daffainfo/nuclei.git
feat: Improve DSL function UX #1295
parent
dfe284664c
commit
c61ec5f673
|
@ -15,6 +15,7 @@ import (
|
|||
"math/rand"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -26,369 +27,378 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
numbers = "1234567890"
|
||||
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
withCutSetArgsSize = 2
|
||||
withBaseRandArgsSize = 3
|
||||
withMaxRandArgsSize = withCutSetArgsSize
|
||||
numbers = "1234567890"
|
||||
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
)
|
||||
|
||||
var ErrDSLArguments = errors.New("invalid arguments provided to dsl")
|
||||
var invalidDslFunctionError = errors.New("invalid DSL function signature")
|
||||
var invalidDslFunctionMessageTemplate = "correct method signature '%s'. %w"
|
||||
|
||||
var functions = map[string]govaluate.ExpressionFunction{
|
||||
"len": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
length := len(types.ToString(args[0]))
|
||||
return float64(length), nil
|
||||
},
|
||||
"toupper": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.ToUpper(types.ToString(args[0])), nil
|
||||
},
|
||||
"tolower": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.ToLower(types.ToString(args[0])), nil
|
||||
},
|
||||
"replace": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.ReplaceAll(types.ToString(args[0]), types.ToString(args[1]), types.ToString(args[2])), nil
|
||||
},
|
||||
"replace_regex": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
compiled, err := regexp.Compile(types.ToString(args[1]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return compiled.ReplaceAllString(types.ToString(args[0]), types.ToString(args[2])), nil
|
||||
},
|
||||
"trim": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.Trim(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
},
|
||||
"trimleft": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.TrimLeft(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
},
|
||||
"trimright": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.TrimRight(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
},
|
||||
"trimspace": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.TrimSpace(types.ToString(args[0])), nil
|
||||
},
|
||||
"trimprefix": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.TrimPrefix(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
},
|
||||
"trimsuffix": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.TrimSuffix(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
},
|
||||
"reverse": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return reverseString(types.ToString(args[0])), nil
|
||||
},
|
||||
// encoding
|
||||
"base64": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
sEnc := base64.StdEncoding.EncodeToString([]byte(types.ToString(args[0])))
|
||||
var dslFunctions map[string]dslFunction
|
||||
|
||||
return sEnc, nil
|
||||
},
|
||||
"gzip": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
buffer := &bytes.Buffer{}
|
||||
writer := gzip.NewWriter(buffer)
|
||||
if _, err := writer.Write([]byte(args[0].(string))); err != nil {
|
||||
return "", err
|
||||
}
|
||||
_ = writer.Close()
|
||||
type dslFunction struct {
|
||||
signature string
|
||||
expressFunc govaluate.ExpressionFunction
|
||||
}
|
||||
|
||||
return buffer.String(), nil
|
||||
},
|
||||
// python encodes to base64 with lines of 76 bytes terminated by new line "\n"
|
||||
"base64_py": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
sEnc := base64.StdEncoding.EncodeToString([]byte(types.ToString(args[0])))
|
||||
return deserialization.InsertInto(sEnc, 76, '\n'), nil
|
||||
},
|
||||
"base64_decode": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(types.ToString(args[0]))
|
||||
},
|
||||
"url_encode": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return url.QueryEscape(types.ToString(args[0])), nil
|
||||
},
|
||||
"url_decode": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return url.QueryUnescape(types.ToString(args[0]))
|
||||
},
|
||||
"hex_encode": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return hex.EncodeToString([]byte(types.ToString(args[0]))), nil
|
||||
},
|
||||
"hex_decode": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
hx, _ := hex.DecodeString(types.ToString(args[0]))
|
||||
return string(hx), nil
|
||||
},
|
||||
"html_escape": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return html.EscapeString(types.ToString(args[0])), nil
|
||||
},
|
||||
"html_unescape": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return html.UnescapeString(types.ToString(args[0])), nil
|
||||
},
|
||||
// hashing
|
||||
"md5": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
hash := md5.Sum([]byte(types.ToString(args[0])))
|
||||
func init() {
|
||||
tempDslFunctions := map[string]func(string) dslFunction{
|
||||
"len": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
length := len(types.ToString(args[0]))
|
||||
return float64(length), nil
|
||||
}),
|
||||
"toupper": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.ToUpper(types.ToString(args[0])), nil
|
||||
}),
|
||||
"tolower": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.ToLower(types.ToString(args[0])), nil
|
||||
}),
|
||||
"replace": makeDslFunction(3, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.ReplaceAll(types.ToString(args[0]), types.ToString(args[1]), types.ToString(args[2])), nil
|
||||
}),
|
||||
"replace_regex": makeDslFunction(3, func(args ...interface{}) (interface{}, error) {
|
||||
compiled, err := regexp.Compile(types.ToString(args[1]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return compiled.ReplaceAllString(types.ToString(args[0]), types.ToString(args[2])), nil
|
||||
}),
|
||||
"trim": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.Trim(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
}),
|
||||
"trimleft": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.TrimLeft(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
}),
|
||||
"trimright": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.TrimRight(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
}),
|
||||
"trimspace": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.TrimSpace(types.ToString(args[0])), nil
|
||||
}),
|
||||
"trimprefix": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.TrimPrefix(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
}),
|
||||
"trimsuffix": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.TrimSuffix(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
}),
|
||||
"reverse": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return reverseString(types.ToString(args[0])), nil
|
||||
}),
|
||||
"base64": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return base64.StdEncoding.EncodeToString([]byte(types.ToString(args[0]))), nil
|
||||
}),
|
||||
"gzip": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
buffer := &bytes.Buffer{}
|
||||
writer := gzip.NewWriter(buffer)
|
||||
if _, err := writer.Write([]byte(args[0].(string))); err != nil {
|
||||
return "", err
|
||||
}
|
||||
_ = writer.Close()
|
||||
|
||||
return hex.EncodeToString(hash[:]), nil
|
||||
},
|
||||
"sha256": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
h := sha256.New()
|
||||
if _, err := h.Write([]byte(types.ToString(args[0]))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
},
|
||||
"sha1": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
h := sha1.New()
|
||||
if _, err := h.Write([]byte(types.ToString(args[0]))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
},
|
||||
"mmh3": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return fmt.Sprintf("%d", int32(murmur3.Sum32WithSeed([]byte(types.ToString(args[0])), 0))), nil
|
||||
},
|
||||
// search
|
||||
"contains": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
return strings.Contains(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
},
|
||||
"regex": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
compiled, err := regexp.Compile(types.ToString(args[0]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return compiled.MatchString(types.ToString(args[1])), nil
|
||||
},
|
||||
// random generators
|
||||
"rand_char": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
chars := letters + numbers
|
||||
bad := ""
|
||||
if len(args) >= 1 {
|
||||
chars = types.ToString(args[0])
|
||||
}
|
||||
if len(args) >= withCutSetArgsSize {
|
||||
bad = types.ToString(args[1])
|
||||
}
|
||||
chars = trimAll(chars, bad)
|
||||
return chars[rand.Intn(len(chars))], nil
|
||||
},
|
||||
"rand_base": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
l := 0
|
||||
bad := ""
|
||||
base := letters + numbers
|
||||
return buffer.String(), nil
|
||||
}),
|
||||
"base64_py": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
stdBase64 := base64.StdEncoding.EncodeToString([]byte(types.ToString(args[0])))
|
||||
return deserialization.InsertInto(stdBase64, 76, '\n'), nil
|
||||
}),
|
||||
"base64_decode": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return base64.StdEncoding.DecodeString(types.ToString(args[0]))
|
||||
}),
|
||||
"url_encode": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return url.QueryEscape(types.ToString(args[0])), nil
|
||||
}),
|
||||
"url_decode": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return url.QueryUnescape(types.ToString(args[0]))
|
||||
}),
|
||||
"hex_encode": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return hex.EncodeToString([]byte(types.ToString(args[0]))), nil
|
||||
}),
|
||||
"hex_decode": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
decodeString, err := hex.DecodeString(types.ToString(args[0]))
|
||||
return decodeString, err
|
||||
}),
|
||||
"html_escape": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return html.EscapeString(types.ToString(args[0])), nil
|
||||
}),
|
||||
"html_unescape": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return html.UnescapeString(types.ToString(args[0])), nil
|
||||
}),
|
||||
"md5": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
hash := md5.Sum([]byte(types.ToString(args[0])))
|
||||
return hex.EncodeToString(hash[:]), nil
|
||||
}),
|
||||
"sha256": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
hash := sha256.New()
|
||||
if _, err := hash.Write([]byte(types.ToString(args[0]))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hex.EncodeToString(hash.Sum(nil)), nil
|
||||
}),
|
||||
"sha1": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
hash := sha1.New()
|
||||
if _, err := hash.Write([]byte(types.ToString(args[0]))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hex.EncodeToString(hash.Sum(nil)), nil
|
||||
}),
|
||||
"mmh3": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
|
||||
return fmt.Sprintf("%d", int32(murmur3.Sum32WithSeed([]byte(types.ToString(args[0])), 0))), nil
|
||||
}),
|
||||
"contains": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
return strings.Contains(types.ToString(args[0]), types.ToString(args[1])), nil
|
||||
}),
|
||||
"regex": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
|
||||
compiled, err := regexp.Compile(types.ToString(args[0]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return compiled.MatchString(types.ToString(args[1])), nil
|
||||
}),
|
||||
"rand_char": makeDslWithOptionalArgsFunction(
|
||||
"(optionalCharSet, optionalBachChars) string",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
charSet := letters + numbers
|
||||
badChars := ""
|
||||
|
||||
if len(args) >= 1 {
|
||||
l = int(args[0].(float64))
|
||||
}
|
||||
if len(args) >= withCutSetArgsSize {
|
||||
bad = types.ToString(args[1])
|
||||
}
|
||||
if len(args) >= withBaseRandArgsSize {
|
||||
base = types.ToString(args[2])
|
||||
}
|
||||
base = trimAll(base, bad)
|
||||
return randSeq(base, l), nil
|
||||
},
|
||||
"rand_text_alphanumeric": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
l := 0
|
||||
bad := ""
|
||||
chars := letters + numbers
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
|
||||
if len(args) >= 1 {
|
||||
l = int(args[0].(float64))
|
||||
}
|
||||
if len(args) >= withCutSetArgsSize {
|
||||
bad = types.ToString(args[1])
|
||||
}
|
||||
chars = trimAll(chars, bad)
|
||||
return randSeq(chars, l), nil
|
||||
},
|
||||
"rand_text_alpha": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
l := 0
|
||||
bad := ""
|
||||
chars := letters
|
||||
if argSize >= 1 {
|
||||
charSet = types.ToString(args[0])
|
||||
}
|
||||
if argSize == 2 {
|
||||
badChars = types.ToString(args[1])
|
||||
}
|
||||
|
||||
if len(args) >= 1 {
|
||||
l = int(args[0].(float64))
|
||||
}
|
||||
if len(args) >= withCutSetArgsSize {
|
||||
bad = types.ToString(args[1])
|
||||
}
|
||||
chars = trimAll(chars, bad)
|
||||
return randSeq(chars, l), nil
|
||||
},
|
||||
"rand_text_numeric": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
l := 0
|
||||
bad := ""
|
||||
chars := numbers
|
||||
charSet = trimAll(charSet, badChars)
|
||||
return charSet[rand.Intn(len(charSet))], nil
|
||||
},
|
||||
),
|
||||
"rand_base": makeDslWithOptionalArgsFunction(
|
||||
"(length, optionalCharSet, optionalBadChars) string",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
var length int
|
||||
badChars := ""
|
||||
charSet := letters + numbers
|
||||
|
||||
if len(args) >= 1 {
|
||||
l = int(args[0].(float64))
|
||||
}
|
||||
if len(args) >= withCutSetArgsSize {
|
||||
bad = types.ToString(args[1])
|
||||
}
|
||||
chars = trimAll(chars, bad)
|
||||
return randSeq(chars, l), nil
|
||||
},
|
||||
"rand_int": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 2 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
min := 0
|
||||
max := math.MaxInt32
|
||||
argSize := len(args)
|
||||
if argSize < 1 || argSize > 3 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
|
||||
if len(args) >= 1 {
|
||||
min = int(args[0].(float64))
|
||||
}
|
||||
if len(args) >= withMaxRandArgsSize {
|
||||
max = int(args[1].(float64))
|
||||
}
|
||||
return rand.Intn(max-min) + min, nil
|
||||
},
|
||||
"unixtime": func(args ...interface{}) (interface{}, error) {
|
||||
seconds := 0
|
||||
if len(args) >= 1 {
|
||||
seconds = int(args[0].(float64))
|
||||
}
|
||||
now := time.Now()
|
||||
offset := now.Add(time.Duration(seconds) * time.Second)
|
||||
return float64(offset.Unix()), nil
|
||||
},
|
||||
// Time Functions
|
||||
"waitfor": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
seconds := args[0].(float64)
|
||||
time.Sleep(time.Duration(seconds) * time.Second)
|
||||
return true, nil
|
||||
},
|
||||
// deserialization Functions
|
||||
"generate_java_gadget": func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 3 {
|
||||
return nil, ErrDSLArguments
|
||||
}
|
||||
gadget := args[0].(string)
|
||||
cmd := args[1].(string)
|
||||
length = int(args[0].(float64))
|
||||
|
||||
var encoding string
|
||||
if len(args) > 2 {
|
||||
encoding = args[2].(string)
|
||||
if argSize >= 2 {
|
||||
badChars = types.ToString(args[1])
|
||||
}
|
||||
if argSize == 3 {
|
||||
charSet = types.ToString(args[2])
|
||||
}
|
||||
charSet = trimAll(charSet, badChars)
|
||||
return randSeq(charSet, length), nil
|
||||
},
|
||||
),
|
||||
"rand_text_alphanumeric": makeDslWithOptionalArgsFunction(
|
||||
"(length, optionalBadChars) string",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
length := 0
|
||||
badChars := ""
|
||||
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
|
||||
length = int(args[0].(float64))
|
||||
|
||||
if argSize == 2 {
|
||||
badChars = types.ToString(args[1])
|
||||
}
|
||||
chars := trimAll(letters+numbers, badChars)
|
||||
return randSeq(chars, length), nil
|
||||
},
|
||||
),
|
||||
"rand_text_alpha": makeDslWithOptionalArgsFunction(
|
||||
"(length, optionalBadChars) string",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
var length int
|
||||
badChars := ""
|
||||
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
|
||||
length = int(args[0].(float64))
|
||||
|
||||
if argSize == 2 {
|
||||
badChars = types.ToString(args[1])
|
||||
}
|
||||
chars := trimAll(letters, badChars)
|
||||
return randSeq(chars, length), nil
|
||||
},
|
||||
),
|
||||
"rand_text_numeric": makeDslWithOptionalArgsFunction(
|
||||
"(size int, optionalBadNumbers string) string",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
argSize := len(args)
|
||||
if argSize != 1 && argSize != 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
|
||||
length := args[0].(int)
|
||||
var badNumbers = ""
|
||||
|
||||
if argSize == 2 {
|
||||
badNumbers = types.ToString(args[1])
|
||||
}
|
||||
|
||||
chars := trimAll(numbers, badNumbers)
|
||||
return randSeq(chars, length), nil
|
||||
},
|
||||
),
|
||||
"rand_int": makeDslWithOptionalArgsFunction(
|
||||
"(optionalMin, optionalMax int) int",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
argSize := len(args)
|
||||
if argSize >= 2 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
|
||||
min := 0
|
||||
max := math.MaxInt32
|
||||
|
||||
if argSize >= 1 {
|
||||
min = args[0].(int)
|
||||
}
|
||||
if argSize == 2 {
|
||||
max = args[1].(int)
|
||||
}
|
||||
return rand.Intn(max-min) + min, nil
|
||||
},
|
||||
),
|
||||
"generate_java_gadget": makeDslFunction(3, func(args ...interface{}) (interface{}, error) {
|
||||
gadget := args[0].(string)
|
||||
cmd := args[1].(string)
|
||||
encoding := args[2].(string)
|
||||
data := deserialization.GenerateJavaGadget(gadget, cmd, encoding)
|
||||
return data, nil
|
||||
}),
|
||||
"unixtime": makeDslWithOptionalArgsFunction(
|
||||
"(optionalSeconds uint) float64",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
seconds := 0
|
||||
|
||||
argSize := len(args)
|
||||
if argSize != 0 && argSize != 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
} else if argSize == 1 {
|
||||
seconds = int(args[0].(uint))
|
||||
}
|
||||
|
||||
offset := time.Now().Add(time.Duration(seconds) * time.Second)
|
||||
return float64(offset.Unix()), nil
|
||||
},
|
||||
),
|
||||
"waitfor": makeDslWithOptionalArgsFunction(
|
||||
"(seconds uint)",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
seconds := args[0].(uint)
|
||||
time.Sleep(time.Duration(seconds) * time.Second)
|
||||
return true, nil
|
||||
},
|
||||
),
|
||||
"print_debug": makeDslWithOptionalArgsFunction(
|
||||
"(args ...interface{})",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) < 1 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
gologger.Info().Msgf("print_debug value: %s", fmt.Sprint(args))
|
||||
return true, nil
|
||||
},
|
||||
),
|
||||
"time_now": makeDslWithOptionalArgsFunction(
|
||||
"() float64",
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) == 0 {
|
||||
return nil, invalidDslFunctionError
|
||||
}
|
||||
return float64(time.Now().Unix()), nil
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
dslFunctions = make(map[string]dslFunction, len(tempDslFunctions))
|
||||
for funcName, dslFunc := range tempDslFunctions {
|
||||
dslFunctions[funcName] = dslFunc(funcName)
|
||||
}
|
||||
}
|
||||
|
||||
func createSignaturePart(numberOfParameters int) string {
|
||||
params := make([]string, 0, numberOfParameters)
|
||||
for i := 1; i <= numberOfParameters; i++ {
|
||||
params = append(params, "arg"+strconv.Itoa(i))
|
||||
}
|
||||
return fmt.Sprintf("(%s interface{}) interface{}", strings.Join(params, ", "))
|
||||
}
|
||||
|
||||
func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
|
||||
return func(functionName string) dslFunction {
|
||||
return dslFunction{
|
||||
functionName + signaturePart,
|
||||
dslFunctionLogic,
|
||||
}
|
||||
data := deserialization.GenerateJavaGadget(gadget, cmd, encoding)
|
||||
return data, nil
|
||||
},
|
||||
// for debug purposes
|
||||
"print_debug": func(args ...interface{}) (interface{}, error) {
|
||||
gologger.Info().Msgf("print_debug value: %s", fmt.Sprint(args))
|
||||
return true, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeDslFunction(numberOfParameters int, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
|
||||
return func(functionName string) dslFunction {
|
||||
signature := functionName + createSignaturePart(numberOfParameters)
|
||||
return dslFunction{
|
||||
signature,
|
||||
func(args ...interface{}) (interface{}, error) {
|
||||
if len(args) != numberOfParameters {
|
||||
return nil, fmt.Errorf(invalidDslFunctionMessageTemplate, signature, invalidDslFunctionError)
|
||||
}
|
||||
return dslFunctionLogic(args...)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HelperFunctions returns the dsl helper functions
|
||||
func HelperFunctions() map[string]govaluate.ExpressionFunction {
|
||||
return functions
|
||||
helperFunctions := make(map[string]govaluate.ExpressionFunction, len(dslFunctions))
|
||||
|
||||
for functionName, dslFunction := range dslFunctions {
|
||||
helperFunctions[functionName] = dslFunction.expressFunc
|
||||
}
|
||||
|
||||
return helperFunctions
|
||||
}
|
||||
|
||||
func GetDslFunctionSignatures() []string {
|
||||
result := make([]string, 0, len(dslFunctions))
|
||||
|
||||
for _, dslFunction := range dslFunctions {
|
||||
result = append(result, dslFunction.signature)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// AddHelperFunction allows creation of additional helper functions to be supported with templates
|
||||
func AddHelperFunction(key string, value func(args ...interface{}) (interface{}, error)) error {
|
||||
if _, ok := functions[key]; !ok {
|
||||
functions[key] = value
|
||||
if _, ok := dslFunctions[key]; !ok {
|
||||
dslFunction := dslFunctions[key]
|
||||
dslFunction.signature = "(args ...interface{}) interface{}"
|
||||
dslFunction.expressFunc = value
|
||||
return nil
|
||||
}
|
||||
return errors.New("duplicate helper function key defined")
|
||||
|
|
|
@ -36,7 +36,6 @@ func (m *Matcher) CompileMatchers() error {
|
|||
m.Part = "body"
|
||||
}
|
||||
|
||||
|
||||
// Compile the regexes
|
||||
for _, regex := range m.Regex {
|
||||
compiled, err := regexp.Compile(regex)
|
||||
|
@ -59,7 +58,7 @@ func (m *Matcher) CompileMatchers() error {
|
|||
for _, expr := range m.DSL {
|
||||
compiled, err := govaluate.NewEvaluableExpressionWithFunctions(expr, dsl.HelperFunctions())
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not compile dsl: %s", expr)
|
||||
return fmt.Errorf("could not compile dsl: %s. %w", expr, err)
|
||||
}
|
||||
m.dslCompiled = append(m.dslCompiled, compiled)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue