diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go index 1d60a982..b4c00fb3 100644 --- a/pkg/templates/cache.go +++ b/pkg/templates/cache.go @@ -1,37 +1,39 @@ package templates import ( + "github.com/projectdiscovery/utils/conversion" mapsutil "github.com/projectdiscovery/utils/maps" ) // Templates is a cache for caching and storing templates for reuse. type Cache struct { - items *mapsutil.SyncLockMap[string, parsedTemplateErrHolder] + items *mapsutil.SyncLockMap[string, parsedTemplate] } // New returns a new templates cache func NewCache() *Cache { - return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplateErrHolder]()} + return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplate]()} } -type parsedTemplateErrHolder struct { +type parsedTemplate struct { template *Template + raw string err error } // Has returns true if the cache has a template. The template // is returned along with any errors if found. -func (t *Cache) Has(template string) (*Template, error) { +func (t *Cache) Has(template string) (*Template, []byte, error) { value, ok := t.items.Get(template) if !ok { - return nil, nil + return nil, nil, nil } - return value.template, value.err + return value.template, conversion.Bytes(value.raw), value.err } // Store stores a template with data and error -func (t *Cache) Store(template string, data *Template, err error) { - _ = t.items.Set(template, parsedTemplateErrHolder{template: data, err: err}) +func (t *Cache) Store(id string, tpl *Template, raw []byte, err error) { + _ = t.items.Set(id, parsedTemplate{template: tpl, raw: conversion.String(raw), err: err}) } // Purge the cache diff --git a/pkg/templates/cache_test.go b/pkg/templates/cache_test.go index ffb17736..8ae529d9 100644 --- a/pkg/templates/cache_test.go +++ b/pkg/templates/cache_test.go @@ -11,14 +11,14 @@ func TestCache(t *testing.T) { templates := NewCache() testErr := errors.New("test error") - data, err := templates.Has("test") + data, _, err := templates.Has("test") require.Nil(t, err, "invalid value for err") require.Nil(t, data, "invalid value for data") item := &Template{} - templates.Store("test", item, testErr) - data, err = templates.Has("test") + templates.Store("test", item, nil, testErr) + data, _, err = templates.Has("test") require.Equal(t, testErr, err, "invalid value for err") require.Equal(t, item, data, "invalid value for data") } diff --git a/pkg/templates/compile.go b/pkg/templates/compile.go index f50fd8b1..f831485d 100644 --- a/pkg/templates/compile.go +++ b/pkg/templates/compile.go @@ -1,6 +1,7 @@ package templates import ( + "bytes" "encoding/json" "fmt" "io" @@ -21,7 +22,6 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/templates/signer" "github.com/projectdiscovery/nuclei/v3/pkg/tmplexec" "github.com/projectdiscovery/nuclei/v3/pkg/utils" - "github.com/projectdiscovery/retryablehttp-go" errorutil "github.com/projectdiscovery/utils/errors" stringsutil "github.com/projectdiscovery/utils/strings" ) @@ -52,27 +52,28 @@ func Parse(filePath string, preprocessor Preprocessor, options protocols.Executo panic("not a parser") } if !options.DoNotCache { - if value, err := parser.compiledTemplatesCache.Has(filePath); value != nil { + if value, _, err := parser.compiledTemplatesCache.Has(filePath); value != nil { return value, err } } var reader io.ReadCloser - if utils.IsURL(filePath) { - // use retryablehttp (tls verification is enabled by default in the standard library) - resp, err := retryablehttp.DefaultClient().Get(filePath) - if err != nil { - return nil, err + if !options.DoNotCache { + _, raw, err := parser.parsedTemplatesCache.Has(filePath) + if err == nil && raw != nil { + reader = io.NopCloser(bytes.NewReader(raw)) } - reader = resp.Body - } else { - var err error - reader, err = options.Catalog.OpenFile(filePath) + } + var err error + if reader == nil { + reader, err = utils.ReaderFromPathOrURL(filePath, options.Catalog) if err != nil { return nil, err } } + defer reader.Close() + options.TemplatePath = filePath template, err := ParseTemplateFromReader(reader, preprocessor, options.Copy()) if err != nil { @@ -88,7 +89,7 @@ func Parse(filePath string, preprocessor Preprocessor, options protocols.Executo } template.Path = filePath if !options.DoNotCache { - parser.compiledTemplatesCache.Store(filePath, template, err) + parser.compiledTemplatesCache.Store(filePath, template, nil, err) } return template, nil } diff --git a/pkg/templates/parser.go b/pkg/templates/parser.go index 1047f686..f63990f3 100644 --- a/pkg/templates/parser.go +++ b/pkg/templates/parser.go @@ -3,18 +3,24 @@ package templates import ( "encoding/json" "fmt" + "io" "github.com/projectdiscovery/nuclei/v3/pkg/catalog" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" "github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/nuclei/v3/pkg/utils/stats" + yamlutil "github.com/projectdiscovery/nuclei/v3/pkg/utils/yaml" + fileutil "github.com/projectdiscovery/utils/file" "gopkg.in/yaml.v2" ) type Parser struct { - ShouldValidate bool - NoStrictSyntax bool - parsedTemplatesCache *Cache + ShouldValidate bool + NoStrictSyntax bool + // this cache can be copied safely between ephemeral instances + parsedTemplatesCache *Cache + // this cache might potentially contain references to heap objects + // it's recommended to always empty it at the end of execution compiledTemplatesCache *Cache } @@ -69,13 +75,29 @@ func (p *Parser) LoadTemplate(templatePath string, t any, extraTags []string, ca // ParseTemplate parses a template and returns a *templates.Template structure func (p *Parser) ParseTemplate(templatePath string, catalog catalog.Catalog) (any, error) { - if value, err := p.parsedTemplatesCache.Has(templatePath); value != nil { + value, _, err := p.parsedTemplatesCache.Has(templatePath) + if value != nil { return value, err } - data, err := utils.ReadFromPathOrURL(templatePath, catalog) + + reader, err := utils.ReaderFromPathOrURL(templatePath, catalog) if err != nil { return nil, err } + defer reader.Close() + + data, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + // pre-process directives only for local files + if fileutil.FileExists(templatePath) && config.GetTemplateFormatFromExt(templatePath) == config.YAML { + data, err = yamlutil.PreProcess(data) + if err != nil { + return nil, err + } + } template := &Template{} @@ -95,7 +117,7 @@ func (p *Parser) ParseTemplate(templatePath string, catalog catalog.Catalog) (an return nil, err } - p.parsedTemplatesCache.Store(templatePath, template, nil) + p.parsedTemplatesCache.Store(templatePath, template, data, nil) return template, nil } diff --git a/pkg/templates/parser_stats.go b/pkg/templates/parser_stats.go index eeee0890..1b555c16 100644 --- a/pkg/templates/parser_stats.go +++ b/pkg/templates/parser_stats.go @@ -8,6 +8,5 @@ const ( HeadlessFlagWarningStats = "headless-flag-missing-warnings" TemplatesExecutedStats = "templates-executed" CodeFlagWarningStats = "code-flag-missing-warnings" - // Note: this is redefined in workflows.go to avoid circular dependency, so make sure to keep it in sync - SkippedUnsignedStatsTODO = "skipped-unsigned-stats" // tracks loading of unsigned templates + SkippedUnsignedStats = "skipped-unsigned-stats" // tracks loading of unsigned templates ) diff --git a/pkg/templates/parser_test.go b/pkg/templates/parser_test.go index c783024b..324db094 100644 --- a/pkg/templates/parser_test.go +++ b/pkg/templates/parser_test.go @@ -99,7 +99,7 @@ func TestLoadTemplate(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - p.parsedTemplatesCache.Store(tc.name, tc.template, tc.templateErr) + p.parsedTemplatesCache.Store(tc.name, tc.template, nil, tc.templateErr) tagFilter, err := NewTagFilter(&tc.filter) require.Nil(t, err) @@ -141,7 +141,7 @@ func TestLoadTemplate(t *testing.T) { SeverityHolder: severity.Holder{Severity: severity.Medium}, }, } - p.parsedTemplatesCache.Store(name, template, nil) + p.parsedTemplatesCache.Store(name, template, nil, nil) tagFilter, err := NewTagFilter(&Config{}) require.Nil(t, err) diff --git a/pkg/templates/workflows.go b/pkg/templates/workflows.go index 03c84e2e..c402ce76 100644 --- a/pkg/templates/workflows.go +++ b/pkg/templates/workflows.go @@ -10,11 +10,6 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/workflows" ) -const ( - // Note: we redefine to avoid cyclic dependency but it should be same as parsers.SkippedUnsignedStats - SkippedUnsignedStats = "skipped-unsigned-stats" // tracks loading of unsigned templates -) - // compileWorkflow compiles the workflow for execution func compileWorkflow(path string, preprocessor Preprocessor, options *protocols.ExecutorOptions, workflow *workflows.Workflow, loader model.WorkflowLoader) { for _, workflow := range workflow.Workflows { diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 87303c9e..0d95b13f 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -7,10 +7,7 @@ import ( "strings" "github.com/projectdiscovery/nuclei/v3/pkg/catalog" - "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" - "github.com/projectdiscovery/nuclei/v3/pkg/utils/yaml" "github.com/projectdiscovery/retryablehttp-go" - fileutil "github.com/projectdiscovery/utils/file" ) func IsBlank(value string) bool { @@ -35,38 +32,20 @@ func IsURL(input string) bool { } // ReadFromPathOrURL reads and returns the contents of a file or url. -func ReadFromPathOrURL(templatePath string, catalog catalog.Catalog) (data []byte, err error) { - var reader io.Reader +func ReaderFromPathOrURL(templatePath string, catalog catalog.Catalog) (io.ReadCloser, error) { if IsURL(templatePath) { resp, err := retryablehttp.DefaultClient().Get(templatePath) if err != nil { return nil, err } - defer resp.Body.Close() - reader = resp.Body + return resp.Body, nil } else { f, err := catalog.OpenFile(templatePath) if err != nil { return nil, err } - defer f.Close() - reader = f + return f, nil } - - data, err = io.ReadAll(reader) - if err != nil { - return nil, err - } - - // pre-process directives only for local files - if fileutil.FileExists(templatePath) && config.GetTemplateFormatFromExt(templatePath) == config.YAML { - data, err = yaml.PreProcess(data) - if err != nil { - return nil, err - } - } - - return } // StringSliceContains checks if a string slice contains a string.