caching content + merging caches

dev
mzack 2024-03-13 21:02:36 +01:00
parent 3685379960
commit e9f6febe01
8 changed files with 60 additions and 62 deletions

View File

@ -1,37 +1,39 @@
package templates package templates
import ( import (
"github.com/projectdiscovery/utils/conversion"
mapsutil "github.com/projectdiscovery/utils/maps" mapsutil "github.com/projectdiscovery/utils/maps"
) )
// Templates is a cache for caching and storing templates for reuse. // Templates is a cache for caching and storing templates for reuse.
type Cache struct { type Cache struct {
items *mapsutil.SyncLockMap[string, parsedTemplateErrHolder] items *mapsutil.SyncLockMap[string, parsedTemplate]
} }
// New returns a new templates cache // New returns a new templates cache
func NewCache() *Cache { func NewCache() *Cache {
return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplateErrHolder]()} return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplate]()}
} }
type parsedTemplateErrHolder struct { type parsedTemplate struct {
template *Template template *Template
raw string
err error err error
} }
// Has returns true if the cache has a template. The template // Has returns true if the cache has a template. The template
// is returned along with any errors if found. // is returned along with any errors if found.
func (t *Cache) Has(template string) (*Template, error) { func (t *Cache) Has(template string) (*Template, []byte, error) {
value, ok := t.items.Get(template) value, ok := t.items.Get(template)
if !ok { if !ok {
return nil, nil return nil, nil, nil
} }
return value.template, value.err return value.template, conversion.Bytes(value.raw), value.err
} }
// Store stores a template with data and error // Store stores a template with data and error
func (t *Cache) Store(template string, data *Template, err error) { func (t *Cache) Store(id string, tpl *Template, raw []byte, err error) {
_ = t.items.Set(template, parsedTemplateErrHolder{template: data, err: err}) _ = t.items.Set(id, parsedTemplate{template: tpl, raw: conversion.String(raw), err: err})
} }
// Purge the cache // Purge the cache

View File

@ -11,14 +11,14 @@ func TestCache(t *testing.T) {
templates := NewCache() templates := NewCache()
testErr := errors.New("test error") testErr := errors.New("test error")
data, err := templates.Has("test") data, _, err := templates.Has("test")
require.Nil(t, err, "invalid value for err") require.Nil(t, err, "invalid value for err")
require.Nil(t, data, "invalid value for data") require.Nil(t, data, "invalid value for data")
item := &Template{} item := &Template{}
templates.Store("test", item, testErr) templates.Store("test", item, nil, testErr)
data, err = templates.Has("test") data, _, err = templates.Has("test")
require.Equal(t, testErr, err, "invalid value for err") require.Equal(t, testErr, err, "invalid value for err")
require.Equal(t, item, data, "invalid value for data") require.Equal(t, item, data, "invalid value for data")
} }

View File

@ -1,6 +1,7 @@
package templates package templates
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -21,7 +22,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/templates/signer" "github.com/projectdiscovery/nuclei/v3/pkg/templates/signer"
"github.com/projectdiscovery/nuclei/v3/pkg/tmplexec" "github.com/projectdiscovery/nuclei/v3/pkg/tmplexec"
"github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/retryablehttp-go"
errorutil "github.com/projectdiscovery/utils/errors" errorutil "github.com/projectdiscovery/utils/errors"
stringsutil "github.com/projectdiscovery/utils/strings" stringsutil "github.com/projectdiscovery/utils/strings"
) )
@ -52,27 +52,28 @@ func Parse(filePath string, preprocessor Preprocessor, options protocols.Executo
panic("not a parser") panic("not a parser")
} }
if !options.DoNotCache { if !options.DoNotCache {
if value, err := parser.compiledTemplatesCache.Has(filePath); value != nil { if value, _, err := parser.compiledTemplatesCache.Has(filePath); value != nil {
return value, err return value, err
} }
} }
var reader io.ReadCloser var reader io.ReadCloser
if utils.IsURL(filePath) { if !options.DoNotCache {
// use retryablehttp (tls verification is enabled by default in the standard library) _, raw, err := parser.parsedTemplatesCache.Has(filePath)
resp, err := retryablehttp.DefaultClient().Get(filePath) if err == nil && raw != nil {
if err != nil { reader = io.NopCloser(bytes.NewReader(raw))
return nil, err }
} }
reader = resp.Body
} else {
var err error var err error
reader, err = options.Catalog.OpenFile(filePath) if reader == nil {
reader, err = utils.ReaderFromPathOrURL(filePath, options.Catalog)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
defer reader.Close() defer reader.Close()
options.TemplatePath = filePath options.TemplatePath = filePath
template, err := ParseTemplateFromReader(reader, preprocessor, options.Copy()) template, err := ParseTemplateFromReader(reader, preprocessor, options.Copy())
if err != nil { if err != nil {
@ -88,7 +89,7 @@ func Parse(filePath string, preprocessor Preprocessor, options protocols.Executo
} }
template.Path = filePath template.Path = filePath
if !options.DoNotCache { if !options.DoNotCache {
parser.compiledTemplatesCache.Store(filePath, template, err) parser.compiledTemplatesCache.Store(filePath, template, nil, err)
} }
return template, nil return template, nil
} }

View File

@ -3,18 +3,24 @@ package templates
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog" "github.com/projectdiscovery/nuclei/v3/pkg/catalog"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/nuclei/v3/pkg/utils/stats" "github.com/projectdiscovery/nuclei/v3/pkg/utils/stats"
yamlutil "github.com/projectdiscovery/nuclei/v3/pkg/utils/yaml"
fileutil "github.com/projectdiscovery/utils/file"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
type Parser struct { type Parser struct {
ShouldValidate bool ShouldValidate bool
NoStrictSyntax bool NoStrictSyntax bool
// this cache can be copied safely between ephemeral instances
parsedTemplatesCache *Cache parsedTemplatesCache *Cache
// this cache might potentially contain references to heap objects
// it's recommended to always empty it at the end of execution
compiledTemplatesCache *Cache compiledTemplatesCache *Cache
} }
@ -69,13 +75,29 @@ func (p *Parser) LoadTemplate(templatePath string, t any, extraTags []string, ca
// ParseTemplate parses a template and returns a *templates.Template structure // ParseTemplate parses a template and returns a *templates.Template structure
func (p *Parser) ParseTemplate(templatePath string, catalog catalog.Catalog) (any, error) { func (p *Parser) ParseTemplate(templatePath string, catalog catalog.Catalog) (any, error) {
if value, err := p.parsedTemplatesCache.Has(templatePath); value != nil { value, _, err := p.parsedTemplatesCache.Has(templatePath)
if value != nil {
return value, err return value, err
} }
data, err := utils.ReadFromPathOrURL(templatePath, catalog)
reader, err := utils.ReaderFromPathOrURL(templatePath, catalog)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
// pre-process directives only for local files
if fileutil.FileExists(templatePath) && config.GetTemplateFormatFromExt(templatePath) == config.YAML {
data, err = yamlutil.PreProcess(data)
if err != nil {
return nil, err
}
}
template := &Template{} template := &Template{}
@ -95,7 +117,7 @@ func (p *Parser) ParseTemplate(templatePath string, catalog catalog.Catalog) (an
return nil, err return nil, err
} }
p.parsedTemplatesCache.Store(templatePath, template, nil) p.parsedTemplatesCache.Store(templatePath, template, data, nil)
return template, nil return template, nil
} }

View File

@ -8,6 +8,5 @@ const (
HeadlessFlagWarningStats = "headless-flag-missing-warnings" HeadlessFlagWarningStats = "headless-flag-missing-warnings"
TemplatesExecutedStats = "templates-executed" TemplatesExecutedStats = "templates-executed"
CodeFlagWarningStats = "code-flag-missing-warnings" CodeFlagWarningStats = "code-flag-missing-warnings"
// Note: this is redefined in workflows.go to avoid circular dependency, so make sure to keep it in sync SkippedUnsignedStats = "skipped-unsigned-stats" // tracks loading of unsigned templates
SkippedUnsignedStatsTODO = "skipped-unsigned-stats" // tracks loading of unsigned templates
) )

View File

@ -99,7 +99,7 @@ func TestLoadTemplate(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
p.parsedTemplatesCache.Store(tc.name, tc.template, tc.templateErr) p.parsedTemplatesCache.Store(tc.name, tc.template, nil, tc.templateErr)
tagFilter, err := NewTagFilter(&tc.filter) tagFilter, err := NewTagFilter(&tc.filter)
require.Nil(t, err) require.Nil(t, err)
@ -141,7 +141,7 @@ func TestLoadTemplate(t *testing.T) {
SeverityHolder: severity.Holder{Severity: severity.Medium}, SeverityHolder: severity.Holder{Severity: severity.Medium},
}, },
} }
p.parsedTemplatesCache.Store(name, template, nil) p.parsedTemplatesCache.Store(name, template, nil, nil)
tagFilter, err := NewTagFilter(&Config{}) tagFilter, err := NewTagFilter(&Config{})
require.Nil(t, err) require.Nil(t, err)

View File

@ -10,11 +10,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/workflows" "github.com/projectdiscovery/nuclei/v3/pkg/workflows"
) )
const (
// Note: we redefine to avoid cyclic dependency but it should be same as parsers.SkippedUnsignedStats
SkippedUnsignedStats = "skipped-unsigned-stats" // tracks loading of unsigned templates
)
// compileWorkflow compiles the workflow for execution // compileWorkflow compiles the workflow for execution
func compileWorkflow(path string, preprocessor Preprocessor, options *protocols.ExecutorOptions, workflow *workflows.Workflow, loader model.WorkflowLoader) { func compileWorkflow(path string, preprocessor Preprocessor, options *protocols.ExecutorOptions, workflow *workflows.Workflow, loader model.WorkflowLoader) {
for _, workflow := range workflow.Workflows { for _, workflow := range workflow.Workflows {

View File

@ -7,10 +7,7 @@ import (
"strings" "strings"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog" "github.com/projectdiscovery/nuclei/v3/pkg/catalog"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/utils/yaml"
"github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/retryablehttp-go"
fileutil "github.com/projectdiscovery/utils/file"
) )
func IsBlank(value string) bool { func IsBlank(value string) bool {
@ -35,38 +32,20 @@ func IsURL(input string) bool {
} }
// ReadFromPathOrURL reads and returns the contents of a file or url. // ReadFromPathOrURL reads and returns the contents of a file or url.
func ReadFromPathOrURL(templatePath string, catalog catalog.Catalog) (data []byte, err error) { func ReaderFromPathOrURL(templatePath string, catalog catalog.Catalog) (io.ReadCloser, error) {
var reader io.Reader
if IsURL(templatePath) { if IsURL(templatePath) {
resp, err := retryablehttp.DefaultClient().Get(templatePath) resp, err := retryablehttp.DefaultClient().Get(templatePath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() return resp.Body, nil
reader = resp.Body
} else { } else {
f, err := catalog.OpenFile(templatePath) f, err := catalog.OpenFile(templatePath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer f.Close() return f, nil
reader = f
} }
data, err = io.ReadAll(reader)
if err != nil {
return nil, err
}
// pre-process directives only for local files
if fileutil.FileExists(templatePath) && config.GetTemplateFormatFromExt(templatePath) == config.YAML {
data, err = yaml.PreProcess(data)
if err != nil {
return nil, err
}
}
return
} }
// StringSliceContains checks if a string slice contains a string. // StringSliceContains checks if a string slice contains a string.