solver: implement content based cache support

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
docker-18.09
Tonis Tiigi 2017-07-31 15:06:59 -07:00
parent fcf2aab63e
commit 8738929b8c
21 changed files with 1119 additions and 319 deletions

View File

@ -47,12 +47,12 @@ func Checksum(ctx context.Context, ref cache.ImmutableRef, path string) (digest.
return getDefaultManager().Checksum(ctx, ref, path)
}
func GetCacheContext(ctx context.Context, ref cache.ImmutableRef) (CacheContext, error) {
return getDefaultManager().GetCacheContext(ctx, ref)
func GetCacheContext(ctx context.Context, md *metadata.StorageItem) (CacheContext, error) {
return getDefaultManager().GetCacheContext(ctx, md)
}
func SetCacheContext(ctx context.Context, ref cache.ImmutableRef, cc CacheContext) error {
return getDefaultManager().SetCacheContext(ctx, ref, cc)
func SetCacheContext(ctx context.Context, md *metadata.StorageItem, cc CacheContext) error {
return getDefaultManager().SetCacheContext(ctx, md, cc)
}
type CacheContext interface {
@ -67,45 +67,57 @@ type Hashed interface {
type cacheManager struct {
locker *locker.Locker
lru *simplelru.LRU
lruMu sync.Mutex
}
func (cm *cacheManager) Checksum(ctx context.Context, ref cache.ImmutableRef, p string) (digest.Digest, error) {
cc, err := cm.GetCacheContext(ctx, ref)
cc, err := cm.GetCacheContext(ctx, ensureOriginMetadata(ref.Metadata()))
if err != nil {
return "", nil
}
return cc.Checksum(ctx, ref, p)
}
func (cm *cacheManager) GetCacheContext(ctx context.Context, ref cache.ImmutableRef) (CacheContext, error) {
cm.locker.Lock(ref.ID())
v, ok := cm.lru.Get(ref.ID())
func (cm *cacheManager) GetCacheContext(ctx context.Context, md *metadata.StorageItem) (CacheContext, error) {
cm.locker.Lock(md.ID())
cm.lruMu.Lock()
v, ok := cm.lru.Get(md.ID())
cm.lruMu.Unlock()
if ok {
cm.locker.Unlock(ref.ID())
cm.locker.Unlock(md.ID())
return v.(*cacheContext), nil
}
cc, err := newCacheContext(ref.Metadata())
cc, err := newCacheContext(md)
if err != nil {
cm.locker.Unlock(ref.ID())
cm.locker.Unlock(md.ID())
return nil, err
}
cm.lru.Add(ref.ID(), cc)
cm.locker.Unlock(ref.ID())
cm.lruMu.Lock()
cm.lru.Add(md.ID(), cc)
cm.lruMu.Unlock()
cm.locker.Unlock(md.ID())
return cc, nil
}
func (cm *cacheManager) SetCacheContext(ctx context.Context, ref cache.ImmutableRef, cci CacheContext) error {
func (cm *cacheManager) SetCacheContext(ctx context.Context, md *metadata.StorageItem, cci CacheContext) error {
cc, ok := cci.(*cacheContext)
if !ok {
return errors.Errorf("invalid cachecontext: %T", cc)
}
if ref.ID() != cc.md.ID() {
return errors.New("saving cachecontext under different ID not supported")
if md.ID() != cc.md.ID() {
cc = &cacheContext{
md: md,
tree: cci.(*cacheContext).tree,
dirtyMap: map[string]struct{}{},
}
} else {
if err := cc.save(); err != nil {
return err
}
}
if err := cc.save(); err != nil {
return err
}
cm.lru.Add(ref.ID(), cc)
cm.lruMu.Lock()
cm.lru.Add(md.ID(), cc)
cm.lruMu.Unlock()
return nil
}
@ -193,7 +205,9 @@ func (cc *cacheContext) save() error {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.dirty = true
if cc.txn != nil {
cc.commitActiveTransaction()
}
var l CacheRecords
node := cc.tree.Root()
@ -231,10 +245,23 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil
}
cc.mu.Lock()
defer cc.mu.Unlock()
if cc.txn == nil {
cc.txn = cc.tree.Txn()
cc.node = cc.tree.Root()
// root is not called by HandleChange. need to fake it
if _, ok := cc.node.Get([]byte("/")); !ok {
cc.txn.Insert([]byte("/"), &CacheRecord{
Type: CacheRecordTypeDirHeader,
Digest: digest.FromBytes(nil),
})
cc.txn.Insert([]byte(""), &CacheRecord{
Type: CacheRecordTypeDir,
})
}
}
if kind == fsutil.ChangeKindDelete {
v, ok := cc.txn.Delete(k)
if ok {
@ -245,7 +272,6 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil
d = ""
}
cc.dirtyMap[d] = struct{}{}
cc.mu.Unlock()
return
}
@ -256,7 +282,6 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil
h, ok := fi.(Hashed)
if !ok {
cc.mu.Unlock()
return errors.Errorf("invalid fileinfo: %s", p)
}
@ -287,7 +312,6 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil
d = ""
}
cc.dirtyMap[d] = struct{}{}
cc.mu.Unlock()
return nil
}
@ -405,12 +429,12 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir
switch cr.Type {
case CacheRecordTypeDir:
h := sha256.New()
iter := root.Iterator()
next := append(k, []byte("/")...)
iter.SeekPrefix(next)
iter := root.Seek(next)
subk := next
ok := true
for {
subk, _, ok := iter.Next()
if !ok || bytes.Compare(next, subk) > 0 {
if !ok || !bytes.HasPrefix(subk, next) {
break
}
h.Write(bytes.TrimPrefix(subk, k))
@ -422,9 +446,10 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir
h.Write([]byte(subcr.Digest))
if subcr.Type == CacheRecordTypeDir { // skip subfiles
next = append(k, []byte("/\xff")...)
iter.SeekPrefix(next)
next := append(subk, []byte("/\xff")...)
iter = root.Seek(next)
}
subk, _, ok = iter.Next()
}
dgst = digest.NewDigest(digest.SHA256, h)
default:
@ -565,3 +590,19 @@ func addParentToMap(d string, m map[string]struct{}) {
m[d] = struct{}{}
addParentToMap(d, m)
}
func ensureOriginMetadata(md *metadata.StorageItem) *metadata.StorageItem {
v := md.Get("cache.equalMutable") // TODO: const
if v == nil {
return md
}
var mutable string
if err := v.Unmarshal(&mutable); err != nil {
return md
}
si, ok := md.Storage().Get(mutable)
if ok {
return &si
}
return md
}

View File

@ -92,7 +92,7 @@ func TestChecksumBasicFile(t *testing.T) {
dgst, err = cc.Checksum(context.TODO(), ref, "/")
assert.NoError(t, err)
assert.Equal(t, digest.Digest("sha256:0d87c8c2a606f961483cd4c5dc0350a4136a299b4066eea4a969d6ed756614cd"), dgst)
assert.Equal(t, digest.Digest("sha256:7378af5287e8b417b6cbc63154d300e130983bfc645e35e86fdadf6f5060468a"), dgst)
dgst, err = cc.Checksum(context.TODO(), ref, "d0")
assert.NoError(t, err)

View File

@ -1,22 +1,26 @@
package instructioncache
import (
"strings"
"github.com/boltdb/bolt"
"github.com/moby/buildkit/cache"
"github.com/moby/buildkit/cache/metadata"
digest "github.com/opencontainers/go-digest"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/net/context"
)
const cacheKey = "buildkit.instructioncache"
const contentCacheKey = "buildkit.instructioncache.content"
type LocalStore struct {
MetadataStore *metadata.Store
Cache cache.Accessor
}
func (ls *LocalStore) Set(key string, value interface{}) error {
func (ls *LocalStore) Set(key digest.Digest, value interface{}) error {
ref, ok := value.(cache.ImmutableRef)
if !ok {
return errors.Errorf("invalid ref")
@ -25,21 +29,21 @@ func (ls *LocalStore) Set(key string, value interface{}) error {
if err != nil {
return err
}
v.Index = index(key)
v.Index = index(key.String())
si, _ := ls.MetadataStore.Get(ref.ID())
return si.Update(func(b *bolt.Bucket) error {
return si.SetValue(b, index(key), v)
return si.SetValue(b, v.Index, v)
})
}
func (ls *LocalStore) Lookup(ctx context.Context, key string) (interface{}, error) {
snaps, err := ls.MetadataStore.Search(index(key))
func (ls *LocalStore) Lookup(ctx context.Context, key digest.Digest) (interface{}, error) {
snaps, err := ls.MetadataStore.Search(index(key.String()))
if err != nil {
return nil, err
}
for _, s := range snaps {
v := s.Get(index(key))
v := s.Get(index(key.String()))
if v != nil {
var id string
if err = v.Unmarshal(&id); err != nil {
@ -56,6 +60,42 @@ func (ls *LocalStore) Lookup(ctx context.Context, key string) (interface{}, erro
return nil, nil
}
func (ls *LocalStore) SetContentMapping(key digest.Digest, value interface{}) error {
ref, ok := value.(cache.ImmutableRef)
if !ok {
return errors.Errorf("invalid ref")
}
v, err := metadata.NewValue(ref.ID())
if err != nil {
return err
}
v.Index = contentIndex(key.String())
si, _ := ls.MetadataStore.Get(ref.ID())
return si.Update(func(b *bolt.Bucket) error {
return si.SetValue(b, v.Index, v)
})
}
func (ls *LocalStore) GetContentMapping(key digest.Digest) ([]digest.Digest, error) {
snaps, err := ls.MetadataStore.Search(contentIndex(key.String()))
if err != nil {
return nil, err
}
var out []digest.Digest
for _, s := range snaps {
for _, k := range s.Keys() {
if strings.HasPrefix(k, index("")) {
out = append(out, digest.Digest(strings.TrimPrefix(k, index("")))) // TODO: type
}
}
}
return out, nil
}
func index(k string) string {
return cacheKey + "::" + k
}
func contentIndex(k string) string {
return contentCacheKey + "::" + k
}

1
cache/metadata.go vendored
View File

@ -18,7 +18,6 @@ import (
const sizeUnknown int64 = -1
const keySize = "snapshot.size"
const keyEqualMutable = "cache.equalMutable"
const keyEqualImmutable = "cache.equalImmutable"
const keyCachePolicy = "cache.cachePolicy"
const keyDescription = "cache.description"
const keyCreatedAt = "cache.createdAt"

View File

@ -205,6 +205,10 @@ func newStorageItem(id string, b *bolt.Bucket, s *Store) (StorageItem, error) {
return si, nil
}
func (s *StorageItem) Storage() *Store { // TODO: used in local source. how to remove this?
return s.storage
}
func (s *StorageItem) ID() string {
return s.id
}

1
cache/refs.go vendored
View File

@ -28,6 +28,7 @@ type MutableRef interface {
Commit(context.Context) (ImmutableRef, error)
Release(context.Context) error
Size(ctx context.Context) (int64, error)
Metadata() *metadata.StorageItem
}
type Mountable interface {

View File

@ -21,10 +21,11 @@ func recvDiffCopy(ds grpc.Stream, dest string, cu CacheUpdater, progress progres
logrus.Debugf("diffcopy took: %v", time.Since(st))
}()
var cf fsutil.ChangeFunc
var ch fsutil.ContentHasher
if cu != nil {
cu.MarkSupported(true)
cf = cu.HandleChange
ch = cu.ContentHasher()
}
_ = cf
return fsutil.Receive(ds.Context(), ds, dest, nil, nil, progress)
return fsutil.Receive(ds.Context(), ds, dest, cf, ch, progress)
}

View File

@ -147,6 +147,7 @@ type FSSendRequestOpt struct {
type CacheUpdater interface {
MarkSupported(bool)
HandleChange(fsutil.ChangeKind, string, os.FileInfo, error) error
ContentHasher() fsutil.ContentHasher
}
// FSSync initializes a transfer of files

View File

@ -15,6 +15,8 @@ import (
"golang.org/x/net/context"
)
const buildCacheType = "buildkit.build.v0"
type buildOp struct {
op *pb.BuildOp
s *Solver
@ -29,18 +31,18 @@ func newBuildOp(v Vertex, op *pb.Op_Build, s *Solver) (Op, error) {
}, nil
}
func (b *buildOp) CacheKey(ctx context.Context, inputs []string) (string, int, error) {
func (b *buildOp) CacheKey(ctx context.Context) (digest.Digest, error) {
dt, err := json.Marshal(struct {
Inputs []string
Exec *pb.BuildOp
Type string
Exec *pb.BuildOp
}{
Inputs: inputs,
Exec: b.op,
Type: buildCacheType,
Exec: b.op,
})
if err != nil {
return "", 0, err
return "", err
}
return digest.FromBytes(dt).String(), 1, nil // TODO: other builders should support many outputs
return digest.FromBytes(dt), nil
}
func (b *buildOp) Run(ctx context.Context, inputs []Reference) (outputs []Reference, retErr error) {
@ -125,17 +127,14 @@ func (b *buildOp) Run(ctx context.Context, inputs []Reference) (outputs []Refere
pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("parentVertex", b.v.Digest()))
defer pw.Close()
refs, err := b.s.getRefs(ctx, vv)
// filer out only the required ref
var out Reference
for i, r := range refs {
if i == index {
out = r
} else {
go r.Release(context.TODO())
}
newref, err := b.s.getRef(ctx, vv, index)
if err != nil {
return nil, err
}
return []Reference{out}, err
return []Reference{newref}, err
}
func (b *buildOp) ContentKeys(context.Context, [][]digest.Digest, []Reference) ([]digest.Digest, error) {
return nil, nil
}

View File

@ -7,14 +7,18 @@ import (
"strings"
"github.com/moby/buildkit/cache"
"github.com/moby/buildkit/cache/contenthash"
"github.com/moby/buildkit/solver/pb"
"github.com/moby/buildkit/util/progress/logs"
"github.com/moby/buildkit/worker"
digest "github.com/opencontainers/go-digest"
"github.com/pkg/errors"
"golang.org/x/net/context"
"golang.org/x/sync/errgroup"
)
const execCacheType = "buildkit.exec.v0"
type execOp struct {
op *pb.ExecOp
cm cache.Manager
@ -29,25 +33,19 @@ func newExecOp(_ Vertex, op *pb.Op_Exec, cm cache.Manager, w worker.Worker) (Op,
}, nil
}
func (e *execOp) CacheKey(ctx context.Context, inputs []string) (string, int, error) {
func (e *execOp) CacheKey(ctx context.Context) (digest.Digest, error) {
dt, err := json.Marshal(struct {
Inputs []string
Exec *pb.ExecOp
Type string
Exec *pb.ExecOp
}{
Inputs: inputs,
Exec: e.op,
Type: execCacheType,
Exec: e.op,
})
if err != nil {
return "", 0, err
}
numRefs := 0
for _, m := range e.op.Mounts {
if m.Output != pb.SkipOutput {
numRefs++
}
return "", err
}
return digest.FromBytes(dt).String(), numRefs, nil
return digest.FromBytes(dt), nil
}
func (e *execOp) Run(ctx context.Context, inputs []Reference) ([]Reference, error) {
@ -130,3 +128,74 @@ func (e *execOp) Run(ctx context.Context, inputs []Reference) ([]Reference, erro
}
return refs, nil
}
func (e *execOp) ContentKeys(ctx context.Context, inputs [][]digest.Digest, refs []Reference) ([]digest.Digest, error) {
if len(refs) == 0 {
return nil, nil
}
// contentKey for exec uses content based checksum for mounts and definition
// based checksum for root
rootIndex := -1
skip := true
srcs := make([]string, len(refs))
for _, m := range e.op.Mounts {
if m.Input != pb.Empty {
if m.Dest != pb.RootMount {
srcs[int(m.Input)] = "/" // TODO: selector
skip = false
} else {
rootIndex = int(m.Input)
}
}
}
if skip {
return nil, nil
}
dgsts := make([]digest.Digest, len(refs))
eg, ctx := errgroup.WithContext(ctx)
for i, ref := range refs {
if srcs[i] == "" {
continue
}
func(i int, ref Reference) {
eg.Go(func() error {
ref, ok := toImmutableRef(ref)
if !ok {
return errors.Errorf("invalid reference")
}
dgst, err := contenthash.Checksum(ctx, ref, srcs[i])
if err != nil {
return err
}
dgsts[i] = dgst
return nil
})
}(i, ref)
}
if err := eg.Wait(); err != nil {
return nil, err
}
var out []digest.Digest
for _, cacheKeys := range inputs {
dt, err := json.Marshal(struct {
Type string
Sources []digest.Digest
Root digest.Digest
Exec *pb.ExecOp
}{
Type: execCacheType,
Sources: dgsts,
Root: cacheKeys[rootIndex],
Exec: e.op,
})
if err != nil {
return nil, err
}
out = append(out, digest.FromBytes(dt))
}
return out, nil
}

View File

@ -70,7 +70,7 @@ func loadLLBVertexRecursive(dgst digest.Digest, op *pb.Op, all map[digest.Digest
if err != nil {
return nil, err
}
vtx.inputs = append(vtx.inputs, &input{index: int(in.Index), vertex: sub})
vtx.inputs = append(vtx.inputs, &input{index: Index(in.Index), vertex: sub})
}
vtx.initClientVertex()
cache[dgst] = vtx

View File

@ -1,7 +1,9 @@
package solver
import (
"encoding/json"
"fmt"
"sync"
"github.com/moby/buildkit/cache"
"github.com/moby/buildkit/client"
@ -10,6 +12,7 @@ import (
"github.com/moby/buildkit/source"
"github.com/moby/buildkit/util/progress"
"github.com/moby/buildkit/worker"
digest "github.com/opencontainers/go-digest"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/net/context"
@ -50,13 +53,16 @@ type Reference interface {
// Op is an implementation for running a vertex
type Op interface {
CacheKey(context.Context, []string) (string, int, error)
CacheKey(context.Context) (digest.Digest, error)
ContentKeys(context.Context, [][]digest.Digest, []Reference) ([]digest.Digest, error)
Run(ctx context.Context, inputs []Reference) (outputs []Reference, err error)
}
type InstructionCache interface {
Lookup(ctx context.Context, key string) (interface{}, error) // TODO: regular ref
Set(key string, ref interface{}) error
Lookup(ctx context.Context, key digest.Digest) (interface{}, error) // TODO: regular ref
Set(key digest.Digest, ref interface{}) error
SetContentMapping(key digest.Digest, value interface{}) error
GetContentMapping(dgst digest.Digest) ([]digest.Digest, error)
}
type Solver struct {
@ -101,34 +107,25 @@ func (s *Solver) Solve(ctx context.Context, id string, v Vertex, exp exporter.Ex
return err
}
refs, err := s.getRefs(ctx, solveVertex)
ref, err := s.getRef(ctx, solveVertex, index)
s.activeState.cancel(j)
if err != nil {
return err
}
defer func() {
for _, r := range refs {
r.Release(context.TODO())
}
ref.Release(context.TODO())
}()
for _, ref := range refs {
immutable, ok := toImmutableRef(ref)
if !ok {
return errors.Errorf("invalid reference for exporting: %T", ref)
}
if err := immutable.Finalize(ctx); err != nil {
return err
}
immutable, ok := toImmutableRef(ref)
if !ok {
return errors.Errorf("invalid reference for exporting: %T", ref)
}
if err := immutable.Finalize(ctx); err != nil {
return err
}
if exp != nil {
r := refs[int(index)]
immutable, ok := toImmutableRef(r)
if !ok {
return errors.Errorf("invalid reference for exporting: %T", r)
}
vv.notifyStarted(ctx)
pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", vv.Digest()))
defer pw.Close()
@ -150,32 +147,13 @@ func (s *Solver) Status(ctx context.Context, id string, statusChan chan *client.
return j.pipe(ctx, statusChan)
}
func (s *Solver) getCacheKey(ctx context.Context, g *vertex) (cacheKey string, numRefs int, retErr error) {
// getCacheKey return a cache key for a single output of a vertex
func (s *Solver) getCacheKey(ctx context.Context, g *vertex, inputs []digest.Digest, index Index) (dgst digest.Digest, retErr error) {
state, err := s.activeState.vertexState(ctx, g.digest, func() (Op, error) {
return s.resolve(g)
})
if err != nil {
return "", 0, err
}
inputs := make([]string, len(g.inputs))
if len(g.inputs) > 0 {
eg, ctx := errgroup.WithContext(ctx)
for i, in := range g.inputs {
func(i int, in *vertex, index int) {
eg.Go(func() error {
k, _, err := s.getCacheKey(ctx, in)
if err != nil {
return err
}
inputs[i] = fmt.Sprintf("%s.%d", k, index)
return nil
})
}(i, in.vertex, in.index)
}
if err := eg.Wait(); err != nil {
return "", 0, err
}
return "", err
}
pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", g.Digest()))
@ -188,124 +166,200 @@ func (s *Solver) getCacheKey(ctx context.Context, g *vertex) (cacheKey string, n
}()
}
return state.GetCacheKey(ctx, func(ctx context.Context, op Op) (string, int, error) {
return op.CacheKey(ctx, inputs)
dgst, err = state.GetCacheKey(ctx, func(ctx context.Context, op Op) (digest.Digest, error) {
return op.CacheKey(ctx)
})
if err != nil {
return "", err
}
dt, err := json.Marshal(struct {
Index Index
Inputs []digest.Digest
Digest digest.Digest
}{
Index: index,
Inputs: inputs,
Digest: dgst,
})
if err != nil {
return "", err
}
return digest.FromBytes(dt), nil
}
func (s *Solver) getRefs(ctx context.Context, g *vertex) (retRef []Reference, retErr error) {
// walkVertex walks all possible cache keys and a evaluated reference for a
// single output of a vertex.
func (s *Solver) walkVertex(ctx context.Context, g *vertex, index Index, fn func(digest.Digest, Reference) (bool, error)) (retErr error) {
state, err := s.activeState.vertexState(ctx, g.digest, func() (Op, error) {
return s.resolve(g)
})
if err != nil {
return nil, err
return err
}
var cacheKey string
if s.cache != nil {
var err error
var numRefs int
cacheKey, numRefs, err = s.getCacheKey(ctx, g)
if err != nil {
return nil, err
}
cacheRefs := make([]Reference, 0, numRefs)
// check if all current refs are already cached
for i := 0; i < numRefs; i++ {
ref, err := s.cache.Lookup(ctx, fmt.Sprintf("%s.%d", cacheKey, i))
if err != nil {
return nil, err
}
if ref == nil { // didn't find ref, release all
for _, ref := range cacheRefs {
ref.Release(context.TODO())
}
break
}
cacheRefs = append(cacheRefs, ref.(Reference))
if len(cacheRefs) == numRefs { // last item
g.recursiveMarkCached(ctx)
return cacheRefs, nil
}
}
}
inputCacheKeysMu := sync.Mutex{}
inputCacheKeys := make([][]digest.Digest, len(g.inputs))
walkerStopped := false
// refs contains all outputs for all input vertexes
refs := make([][]*sharedRef, len(g.inputs))
if len(g.inputs) > 0 {
eg, ctx := errgroup.WithContext(ctx)
for i, in := range g.inputs {
func(i int, in *vertex, index int) {
eg.Go(func() error {
if s.cache != nil {
k, numRefs, err := s.getCacheKey(ctx, in)
if err != nil {
return err
}
ref, err := s.cache.Lookup(ctx, fmt.Sprintf("%s.%d", k, index))
if err != nil {
return err
}
if ref != nil {
if ref, ok := toImmutableRef(ref.(Reference)); ok {
refs[i] = make([]*sharedRef, numRefs)
refs[i][index] = newSharedRef(ref)
in.recursiveMarkCached(ctx)
return nil
}
}
}
inputCtx, cancelInputCtx := context.WithCancel(ctx)
defer cancelInputCtx()
// execute input vertex
r, err := s.getRefs(ctx, in)
if err != nil {
return err
}
for _, r := range r {
refs[i] = append(refs[i], newSharedRef(r))
}
if ref, ok := toImmutableRef(r[index].(Reference)); ok {
// make sure input that is required by next step does not get released in case build is cancelled
if err := cache.CachePolicyRetain(ref); err != nil {
return err
}
}
return nil
})
}(i, in.vertex, in.index)
}
err := eg.Wait()
if err != nil {
for _, r := range refs {
for _, r := range r {
if r != nil {
go r.Release(context.TODO())
}
}
}
return nil, err
}
}
// determine the inputs that were needed
inputRefs := make([]Reference, 0, len(g.inputs))
for i, inp := range g.inputs {
inputRefs = append(inputRefs, refs[i][inp.index].Clone())
}
inputRefs := make([]Reference, len(g.inputs))
defer func() {
for _, r := range inputRefs {
go r.Release(context.TODO())
}
}()
// release anything else
for _, r := range refs {
for _, r := range r {
if r != nil {
go r.Release(context.TODO())
}
}
}()
if len(g.inputs) > 0 {
eg, ctx := errgroup.WithContext(ctx)
for i, in := range g.inputs {
func(i int, in *input) {
eg.Go(func() error {
var inputRef Reference
defer func() {
if inputRef != nil {
go inputRef.Release(context.TODO())
}
}()
err := s.walkVertex(inputCtx, in.vertex, in.index, func(k digest.Digest, ref Reference) (bool, error) {
if k == "" && ref == nil {
// indicator between cache key and reference
if inputRef != nil {
return true, nil
}
// TODO: might be good to block here if other inputs may
// cause cache hits to avoid duplicate work.
return false, nil
}
if ref != nil {
inputRef = ref
return true, nil
}
inputCacheKeysMu.Lock()
defer inputCacheKeysMu.Unlock()
if walkerStopped {
return walkerStopped, nil
}
// try all known combinations together with new key
inputCacheKeysCopy := append([][]digest.Digest{}, inputCacheKeys...)
inputCacheKeysCopy[i] = []digest.Digest{k}
inputCacheKeys[i] = append(inputCacheKeys[i], k)
for _, inputKeys := range combinations(inputCacheKeysCopy) {
cacheKey, err := s.getCacheKey(ctx, g, inputKeys, index)
if err != nil {
return false, err
}
stop, err := fn(cacheKey, nil)
if err != nil {
return false, err
}
if stop {
walkerStopped = true
cancelInputCtx() // parent matched, stop processing current node and its inputs
return true, nil
}
}
// if no parent matched, try looking up current node from cache
if s.cache != nil && inputRef == nil {
lookupRef, err := s.cache.Lookup(ctx, k)
if err != nil {
return false, err
}
if lookupRef != nil {
inputRef = lookupRef.(Reference)
in.vertex.recursiveMarkCached(ctx)
return true, nil
}
}
return false, nil
})
if inputRef != nil {
// make sure that the inputs for other steps don't get released on cancellation
if ref, ok := toImmutableRef(inputRef); ok {
if err := cache.CachePolicyRetain(ref); err != nil {
return err
}
if err := ref.Metadata().Commit(); err != nil {
return err
}
}
}
inputCacheKeysMu.Lock()
defer inputCacheKeysMu.Unlock()
if walkerStopped {
return nil
}
if err != nil {
return err
}
inputRefs[i] = inputRef
inputRef = nil
return nil
})
}(i, in)
}
if err := eg.Wait(); err != nil && !walkerStopped {
return err
}
} else {
cacheKey, err := s.getCacheKey(ctx, g, nil, index)
if err != nil {
return err
}
stop, err := fn(cacheKey, nil)
if err != nil {
return err
}
walkerStopped = stop
}
if walkerStopped {
return nil
}
var contentKeys []digest.Digest
if s.cache != nil {
// try to determine content based key
contentKeys, err = state.op.ContentKeys(ctx, combinations(inputCacheKeys), inputRefs)
if err != nil {
return err
}
for _, k := range contentKeys {
cks, err := s.cache.GetContentMapping(contentKeyWithIndex(k, index))
if err != nil {
return err
}
for _, k := range cks {
stop, err := fn(k, nil)
if err != nil {
return err
}
if stop {
return nil
}
}
}
}
// signal that no more cache keys are coming
stop, err := fn("", nil)
if err != nil {
return err
}
if stop {
return nil
}
pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", g.Digest()))
@ -316,18 +370,105 @@ func (s *Solver) getRefs(ctx context.Context, g *vertex) (retRef []Reference, re
g.notifyCompleted(ctx, false, retErr)
}()
return state.GetRefs(ctx, func(ctx context.Context, op Op) ([]Reference, error) {
ref, err := state.GetRefs(ctx, index, func(ctx context.Context, op Op) ([]Reference, error) {
refs, err := op.Run(ctx, inputRefs)
if err != nil {
return nil, err
}
if s.cache != nil {
mainInputKeys := firstKeys(inputCacheKeys)
for i, ref := range refs {
if err := s.cache.Set(fmt.Sprintf("%s.%d", cacheKey, i), originRef(ref)); err != nil {
logrus.Errorf("failed to save cache for %s: %v", cacheKey, err)
if ref != nil {
cacheKey, err := s.getCacheKey(ctx, g, mainInputKeys, Index(i))
if err != nil {
return nil, err
}
r := originRef(ref)
if err := s.cache.Set(cacheKey, r); err != nil {
logrus.Errorf("failed to save cache for %s: %v", cacheKey, err)
}
for _, k := range contentKeys {
if err := s.cache.SetContentMapping(contentKeyWithIndex(k, Index(i)), r); err != nil {
logrus.Errorf("failed to save content mapping: %v", err)
}
}
}
}
}
return refs, nil
})
if err != nil {
return err
}
// return reference
_, err = fn("", ref)
if err != nil {
return err
}
return nil
}
func (s *Solver) getRef(ctx context.Context, g *vertex, index Index) (ref Reference, retErr error) {
logrus.Debugf("> getRef %s %v %s", g.digest, index, g.name)
defer logrus.Debugf("< getRef %s %v", g.digest, index)
var returnRef Reference
err := s.walkVertex(ctx, g, index, func(ck digest.Digest, ref Reference) (bool, error) {
if ref != nil {
returnRef = ref
return true, nil
}
if ck == "" {
return false, nil
}
lookupRef, err := s.cache.Lookup(ctx, ck)
if err != nil {
return false, err
}
if lookupRef != nil {
g.recursiveMarkCached(ctx)
returnRef = lookupRef.(Reference)
return true, nil
}
return false, nil
})
if err != nil {
return nil, err
}
return returnRef, nil
}
func firstKeys(inp [][]digest.Digest) []digest.Digest {
var out []digest.Digest
for _, v := range inp {
out = append(out, v[0])
}
return out
}
func combinations(inp [][]digest.Digest) [][]digest.Digest {
var out [][]digest.Digest
if len(inp) == 0 {
return inp
}
if len(inp) == 1 {
for _, v := range inp[0] {
out = append(out, []digest.Digest{v})
}
return out
}
for _, v := range inp[0] {
for _, v2 := range combinations(inp[1:]) {
out = append(out, append([]digest.Digest{v}, v2...))
}
}
return out
}
func contentKeyWithIndex(dgst digest.Digest, index Index) digest.Digest {
return digest.FromBytes([]byte(fmt.Sprintf("%s.%d", dgst, index)))
}

View File

@ -5,9 +5,12 @@ import (
"github.com/moby/buildkit/solver/pb"
"github.com/moby/buildkit/source"
digest "github.com/opencontainers/go-digest"
"golang.org/x/net/context"
)
const sourceCacheType = "buildkit.source.v0"
type sourceOp struct {
mu sync.Mutex
op *pb.Op_Source
@ -58,13 +61,16 @@ func (s *sourceOp) instance(ctx context.Context) (source.SourceInstance, error)
return s.src, nil
}
func (s *sourceOp) CacheKey(ctx context.Context, _ []string) (string, int, error) {
func (s *sourceOp) CacheKey(ctx context.Context) (digest.Digest, error) {
src, err := s.instance(ctx)
if err != nil {
return "", 0, err
return "", err
}
k, err := src.CacheKey(ctx)
return k, 1, err
if err != nil {
return "", err
}
return digest.FromBytes([]byte(sourceCacheType + ":" + k)), nil
}
func (s *sourceOp) Run(ctx context.Context, _ []Reference) ([]Reference, error) {
@ -78,3 +84,7 @@ func (s *sourceOp) Run(ctx context.Context, _ []Reference) ([]Reference, error)
}
return []Reference{ref}, nil
}
func (s *sourceOp) ContentKeys(context.Context, [][]digest.Digest, []Reference) ([]digest.Digest, error) {
return nil, nil
}

View File

@ -25,8 +25,7 @@ type state struct {
key digest.Digest
jobs map[*job]struct{}
refs []*sharedRef
cacheKey string
numRefs int
cacheKey digest.Digest
op Op
progressCtx context.Context
cacheCtx context.Context
@ -75,7 +74,7 @@ func (s *activeState) cancel(j *job) {
}
}
func (s *state) GetRefs(ctx context.Context, cb func(context.Context, Op) ([]Reference, error)) ([]Reference, error) {
func (s *state) GetRefs(ctx context.Context, index Index, cb func(context.Context, Op) ([]Reference, error)) (Reference, error) {
_, err := s.Do(ctx, s.key.String(), func(doctx context.Context) (interface{}, error) {
if s.refs != nil {
if err := writeProgressSnapshot(s.progressCtx, ctx); err != nil {
@ -98,14 +97,10 @@ func (s *state) GetRefs(ctx context.Context, cb func(context.Context, Op) ([]Ref
if err != nil {
return nil, err
}
refs := make([]Reference, 0, len(s.refs))
for _, r := range s.refs {
refs = append(refs, r.Clone())
}
return refs, nil
return s.refs[int(index)].Clone(), nil
}
func (s *state) GetCacheKey(ctx context.Context, cb func(context.Context, Op) (string, int, error)) (string, int, error) {
func (s *state) GetCacheKey(ctx context.Context, cb func(context.Context, Op) (digest.Digest, error)) (digest.Digest, error) {
_, err := s.Do(ctx, "cache:"+s.key.String(), func(doctx context.Context) (interface{}, error) {
if s.cacheKey != "" {
if err := writeProgressSnapshot(s.cacheCtx, ctx); err != nil {
@ -113,19 +108,18 @@ func (s *state) GetCacheKey(ctx context.Context, cb func(context.Context, Op) (s
}
return nil, nil
}
cacheKey, numRefs, err := cb(doctx, s.op)
cacheKey, err := cb(doctx, s.op)
if err != nil {
return nil, err
}
s.cacheKey = cacheKey
s.numRefs = numRefs
s.cacheCtx = doctx
return nil, nil
})
if err != nil {
return "", 0, err
return "", err
}
return s.cacheKey, s.numRefs, nil
return s.cacheKey, nil
}
func writeProgressSnapshot(srcCtx, destCtx context.Context) error {

View File

@ -22,14 +22,16 @@ type Vertex interface {
Name() string // change this to general metadata
}
type Index int
// Input is an pointer to a single reference from a vertex by an index.
type Input struct {
Index int
Index Index
Vertex Vertex
}
type input struct {
index int
index Index
vertex *vertex
}

View File

@ -6,6 +6,7 @@ import (
"github.com/boltdb/bolt"
"github.com/moby/buildkit/cache"
"github.com/moby/buildkit/cache/contenthash"
"github.com/moby/buildkit/cache/metadata"
"github.com/moby/buildkit/session"
"github.com/moby/buildkit/session/filesync"
@ -14,6 +15,7 @@ import (
"github.com/moby/buildkit/util/progress"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/tonistiigi/fsutil"
"golang.org/x/net/context"
"golang.org/x/time/rate"
)
@ -139,12 +141,17 @@ func (ls *localSourceHandler) Snapshot(ctx context.Context) (out cache.Immutable
}
}()
cc, err := contenthash.GetCacheContext(ctx, mutable.Metadata())
if err != nil {
return nil, err
}
opt := filesync.FSSendRequestOpt{
Name: ls.src.Name,
IncludePatterns: nil,
OverrideExcludes: false,
DestDir: dest,
CacheUpdater: nil,
CacheUpdater: &cacheUpdater{cc},
ProgressCb: newProgressHandler(ctx, "transferring "+ls.src.Name+":"),
}
@ -157,6 +164,10 @@ func (ls *localSourceHandler) Snapshot(ctx context.Context) (out cache.Immutable
}
lm = nil
if err := contenthash.SetCacheContext(ctx, mutable.Metadata(), cc); err != nil {
return nil, err
}
// skip storing snapshot by the shared key if it already exists
skipStoreSharedKey := false
si, _ := ls.md.Get(mutable.ID())
@ -185,6 +196,7 @@ func (ls *localSourceHandler) Snapshot(ctx context.Context) (out cache.Immutable
if err != nil {
return nil, err
}
mutable = nil // avoid deferred cleanup
return snap, nil
@ -213,3 +225,14 @@ func newProgressHandler(ctx context.Context, id string) func(int, bool) {
}
}
}
type cacheUpdater struct {
contenthash.CacheContext
}
func (cu *cacheUpdater) MarkSupported(bool) {
}
func (cu *cacheUpdater) ContentHasher() fsutil.ContentHasher {
return contenthash.NewFromStat
}

View File

@ -7,7 +7,7 @@ github.com/pmezard/go-difflib v1.0.0
golang.org/x/sys 739734461d1c916b6c72a63d7efda2b27edb369f
github.com/containerd/containerd 036232856fb8f088a844b22f3330bcddb5d44c0a
golang.org/x/sync 450f422ab23cf9881c94e2db30cac0eb1b7cf80c
golang.org/x/sync f52d1811a62927559de87708c8913c1650ce4f26
github.com/sirupsen/logrus v1.0.0
google.golang.org/grpc v1.3.0
github.com/opencontainers/go-digest 21dfd564fd89c944783d00d069f33e3e7123c448
@ -39,5 +39,5 @@ github.com/pkg/profile 5b67d428864e92711fcbd2f8629456121a56d91f
github.com/tonistiigi/fsutil 195d62bee906e45aa700b8ebeb3417f7b126bb23
github.com/stevvooe/continuity 86cec1535a968310e7532819f699ff2830ed7463
github.com/dmcgowan/go-tar 2e2c51242e8993c50445dab7c03c8e7febddd0cf
github.com/hashicorp/go-immutable-radix 8e8ed81f8f0bf1bdd829593fdd5c29922c1ea990
github.com/hashicorp/golang-lru a0d98a5f288019575c6d1f4bb1573fef2d1fcdc4
github.com/hashicorp/go-immutable-radix 826af9ccf0feeee615d546d69b11f8e98da8c8f1 git://github.com/tonistiigi/go-immutable-radix.git
github.com/hashicorp/golang-lru a0d98a5f288019575c6d1f4bb1573fef2d1fcdc4

View File

@ -2,6 +2,7 @@ package iradix
import (
"bytes"
"strings"
"github.com/hashicorp/golang-lru/simplelru"
)
@ -11,7 +12,9 @@ const (
// cache used per transaction. This is used to cache the updates
// to the nodes near the root, while the leaves do not need to be
// cached. This is important for very large transactions to prevent
// the modified cache from growing to be enormous.
// the modified cache from growing to be enormous. This is also used
// to set the max size of the mutation notify maps since those should
// also be bounded in a similar way.
defaultModifiedCache = 8192
)
@ -27,7 +30,11 @@ type Tree struct {
// New returns an empty Tree
func New() *Tree {
t := &Tree{root: &Node{}}
t := &Tree{
root: &Node{
mutateCh: make(chan struct{}),
},
}
return t
}
@ -40,75 +47,208 @@ func (t *Tree) Len() int {
// atomically and returns a new tree when committed. A transaction
// is not thread safe, and should only be used by a single goroutine.
type Txn struct {
root *Node
size int
modified *simplelru.LRU
// root is the modified root for the transaction.
root *Node
// snap is a snapshot of the root node for use if we have to run the
// slow notify algorithm.
snap *Node
// size tracks the size of the tree as it is modified during the
// transaction.
size int
// writable is a cache of writable nodes that have been created during
// the course of the transaction. This allows us to re-use the same
// nodes for further writes and avoid unnecessary copies of nodes that
// have never been exposed outside the transaction. This will only hold
// up to defaultModifiedCache number of entries.
writable *simplelru.LRU
// trackChannels is used to hold channels that need to be notified to
// signal mutation of the tree. This will only hold up to
// defaultModifiedCache number of entries, after which we will set the
// trackOverflow flag, which will cause us to use a more expensive
// algorithm to perform the notifications. Mutation tracking is only
// performed if trackMutate is true.
trackChannels map[chan struct{}]struct{}
trackOverflow bool
trackMutate bool
}
// Txn starts a new transaction that can be used to mutate the tree
func (t *Tree) Txn() *Txn {
txn := &Txn{
root: t.root,
snap: t.root,
size: t.size,
}
return txn
}
// writeNode returns a node to be modified, if the current
// node as already been modified during the course of
// the transaction, it is used in-place.
func (t *Txn) writeNode(n *Node) *Node {
// Ensure the modified set exists
if t.modified == nil {
// TrackMutate can be used to toggle if mutations are tracked. If this is enabled
// then notifications will be issued for affected internal nodes and leaves when
// the transaction is committed.
func (t *Txn) TrackMutate(track bool) {
t.trackMutate = track
}
// trackChannel safely attempts to track the given mutation channel, setting the
// overflow flag if we can no longer track any more. This limits the amount of
// state that will accumulate during a transaction and we have a slower algorithm
// to switch to if we overflow.
func (t *Txn) trackChannel(ch chan struct{}) {
// In overflow, make sure we don't store any more objects.
if t.trackOverflow {
return
}
// If this would overflow the state we reject it and set the flag (since
// we aren't tracking everything that's required any longer).
if len(t.trackChannels) >= defaultModifiedCache {
// Mark that we are in the overflow state
t.trackOverflow = true
// Clear the map so that the channels can be garbage collected. It is
// safe to do this since we have already overflowed and will be using
// the slow notify algorithm.
t.trackChannels = nil
return
}
// Create the map on the fly when we need it.
if t.trackChannels == nil {
t.trackChannels = make(map[chan struct{}]struct{})
}
// Otherwise we are good to track it.
t.trackChannels[ch] = struct{}{}
}
// writeNode returns a node to be modified, if the current node has already been
// modified during the course of the transaction, it is used in-place. Set
// forLeafUpdate to true if you are getting a write node to update the leaf,
// which will set leaf mutation tracking appropriately as well.
func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node {
// Ensure the writable set exists.
if t.writable == nil {
lru, err := simplelru.NewLRU(defaultModifiedCache, nil)
if err != nil {
panic(err)
}
t.modified = lru
t.writable = lru
}
// If this node has already been modified, we can
// continue to use it during this transaction.
if _, ok := t.modified.Get(n); ok {
// If this node has already been modified, we can continue to use it
// during this transaction. We know that we don't need to track it for
// a node update since the node is writable, but if this is for a leaf
// update we track it, in case the initial write to this node didn't
// update the leaf.
if _, ok := t.writable.Get(n); ok {
if t.trackMutate && forLeafUpdate && n.leaf != nil {
t.trackChannel(n.leaf.mutateCh)
}
return n
}
// Copy the existing node
nc := new(Node)
// Mark this node as being mutated.
if t.trackMutate {
t.trackChannel(n.mutateCh)
}
// Mark its leaf as being mutated, if appropriate.
if t.trackMutate && forLeafUpdate && n.leaf != nil {
t.trackChannel(n.leaf.mutateCh)
}
// Copy the existing node. If you have set forLeafUpdate it will be
// safe to replace this leaf with another after you get your node for
// writing. You MUST replace it, because the channel associated with
// this leaf will be closed when this transaction is committed.
nc := &Node{
mutateCh: make(chan struct{}),
leaf: n.leaf,
}
if n.prefix != nil {
nc.prefix = make([]byte, len(n.prefix))
copy(nc.prefix, n.prefix)
}
if n.leaf != nil {
nc.leaf = new(leafNode)
*nc.leaf = *n.leaf
}
if len(n.edges) != 0 {
nc.edges = make([]edge, len(n.edges))
copy(nc.edges, n.edges)
}
// Mark this node as modified
t.modified.Add(n, nil)
// Mark this node as writable.
t.writable.Add(nc, nil)
return nc
}
// Visit all the nodes in the tree under n, and add their mutateChannels to the transaction
// Returns the size of the subtree visited
func (t *Txn) trackChannelsAndCount(n *Node) int {
// Count only leaf nodes
leaves := 0
if n.leaf != nil {
leaves = 1
}
// Mark this node as being mutated.
if t.trackMutate {
t.trackChannel(n.mutateCh)
}
// Mark its leaf as being mutated, if appropriate.
if t.trackMutate && n.leaf != nil {
t.trackChannel(n.leaf.mutateCh)
}
// Recurse on the children
for _, e := range n.edges {
leaves += t.trackChannelsAndCount(e.node)
}
return leaves
}
// mergeChild is called to collapse the given node with its child. This is only
// called when the given node is not a leaf and has a single edge.
func (t *Txn) mergeChild(n *Node) {
// Mark the child node as being mutated since we are about to abandon
// it. We don't need to mark the leaf since we are retaining it if it
// is there.
e := n.edges[0]
child := e.node
if t.trackMutate {
t.trackChannel(child.mutateCh)
}
// Merge the nodes.
n.prefix = concat(n.prefix, child.prefix)
n.leaf = child.leaf
if len(child.edges) != 0 {
n.edges = make([]edge, len(child.edges))
copy(n.edges, child.edges)
} else {
n.edges = nil
}
}
// insert does a recursive insertion
func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface{}, bool) {
// Handle key exhaution
// Handle key exhaustion
if len(search) == 0 {
nc := t.writeNode(n)
var oldVal interface{}
didUpdate := false
if n.isLeaf() {
old := nc.leaf.val
nc.leaf.val = v
return nc, old, true
} else {
nc.leaf = &leafNode{
key: k,
val: v,
}
return nc, nil, false
oldVal = n.leaf.val
didUpdate = true
}
nc := t.writeNode(n, true)
nc.leaf = &leafNode{
mutateCh: make(chan struct{}),
key: k,
val: v,
}
return nc, oldVal, didUpdate
}
// Look for the edge
@ -119,14 +259,16 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
e := edge{
label: search[0],
node: &Node{
mutateCh: make(chan struct{}),
leaf: &leafNode{
key: k,
val: v,
mutateCh: make(chan struct{}),
key: k,
val: v,
},
prefix: search,
},
}
nc := t.writeNode(n)
nc := t.writeNode(n, false)
nc.addEdge(e)
return nc, nil, false
}
@ -137,7 +279,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
search = search[commonPrefix:]
newChild, oldVal, didUpdate := t.insert(child, k, search, v)
if newChild != nil {
nc := t.writeNode(n)
nc := t.writeNode(n, false)
nc.edges[idx].node = newChild
return nc, oldVal, didUpdate
}
@ -145,9 +287,10 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
}
// Split the node
nc := t.writeNode(n)
nc := t.writeNode(n, false)
splitNode := &Node{
prefix: search[:commonPrefix],
mutateCh: make(chan struct{}),
prefix: search[:commonPrefix],
}
nc.replaceEdge(edge{
label: search[0],
@ -155,7 +298,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
})
// Restore the existing child node
modChild := t.writeNode(child)
modChild := t.writeNode(child, false)
splitNode.addEdge(edge{
label: modChild.prefix[commonPrefix],
node: modChild,
@ -164,8 +307,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
// Create a new leaf node
leaf := &leafNode{
key: k,
val: v,
mutateCh: make(chan struct{}),
key: k,
val: v,
}
// If the new key is a subset, add to to this node
@ -179,8 +323,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
splitNode.addEdge(edge{
label: search[0],
node: &Node{
leaf: leaf,
prefix: search,
mutateCh: make(chan struct{}),
leaf: leaf,
prefix: search,
},
})
return nc, nil, false
@ -188,19 +333,19 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface
// delete does a recursive deletion
func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) {
// Check for key exhaution
// Check for key exhaustion
if len(search) == 0 {
if !n.isLeaf() {
return nil, nil
}
// Remove the leaf node
nc := t.writeNode(n)
nc := t.writeNode(n, true)
nc.leaf = nil
// Check if this node should be merged
if n != t.root && len(nc.edges) == 1 {
nc.mergeChild()
t.mergeChild(nc)
}
return nc, n.leaf
}
@ -219,14 +364,17 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) {
return nil, nil
}
// Copy this node
nc := t.writeNode(n)
// Copy this node. WATCH OUT - it's safe to pass "false" here because we
// will only ADD a leaf via nc.mergeChild() if there isn't one due to
// the !nc.isLeaf() check in the logic just below. This is pretty subtle,
// so be careful if you change any of the logic here.
nc := t.writeNode(n, false)
// Delete the edge if the node has no edges
if newChild.leaf == nil && len(newChild.edges) == 0 {
nc.delEdge(label)
if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() {
nc.mergeChild()
t.mergeChild(nc)
}
} else {
nc.edges[idx].node = newChild
@ -234,6 +382,56 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) {
return nc, leaf
}
// delete does a recursive deletion
func (t *Txn) deletePrefix(parent, n *Node, search []byte) (*Node, int) {
// Check for key exhaustion
if len(search) == 0 {
nc := t.writeNode(n, true)
if n.isLeaf() {
nc.leaf = nil
}
nc.edges = nil
return nc, t.trackChannelsAndCount(n)
}
// Look for an edge
label := search[0]
idx, child := n.getEdge(label)
// We make sure that either the child node's prefix starts with the search term, or the search term starts with the child node's prefix
// Need to do both so that we can delete prefixes that don't correspond to any node in the tree
if child == nil || (!bytes.HasPrefix(child.prefix, search) && !bytes.HasPrefix(search, child.prefix)) {
return nil, 0
}
// Consume the search prefix
if len(child.prefix) > len(search) {
search = []byte("")
} else {
search = search[len(child.prefix):]
}
newChild, numDeletions := t.deletePrefix(n, child, search)
if newChild == nil {
return nil, 0
}
// Copy this node. WATCH OUT - it's safe to pass "false" here because we
// will only ADD a leaf via nc.mergeChild() if there isn't one due to
// the !nc.isLeaf() check in the logic just below. This is pretty subtle,
// so be careful if you change any of the logic here.
nc := t.writeNode(n, false)
// Delete the edge if the node has no edges
if newChild.leaf == nil && len(newChild.edges) == 0 {
nc.delEdge(label)
if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() {
t.mergeChild(nc)
}
} else {
nc.edges[idx].node = newChild
}
return nc, numDeletions
}
// Insert is used to add or update a given key. The return provides
// the previous value and a bool indicating if any was set.
func (t *Txn) Insert(k []byte, v interface{}) (interface{}, bool) {
@ -261,6 +459,19 @@ func (t *Txn) Delete(k []byte) (interface{}, bool) {
return nil, false
}
// DeletePrefix is used to delete an entire subtree that matches the prefix
// This will delete all nodes under that prefix
func (t *Txn) DeletePrefix(prefix []byte) bool {
newRoot, numDeletions := t.deletePrefix(nil, t.root, prefix)
if newRoot != nil {
t.root = newRoot
t.size = t.size - numDeletions
return true
}
return false
}
// Root returns the current root of the radix tree within this
// transaction. The root is not safe across insert and delete operations,
// but can be used to read the current state during a transaction.
@ -274,10 +485,115 @@ func (t *Txn) Get(k []byte) (interface{}, bool) {
return t.root.Get(k)
}
// Commit is used to finalize the transaction and return a new tree
// GetWatch is used to lookup a specific key, returning
// the watch channel, value and if it was found
func (t *Txn) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) {
return t.root.GetWatch(k)
}
// Commit is used to finalize the transaction and return a new tree. If mutation
// tracking is turned on then notifications will also be issued.
func (t *Txn) Commit() *Tree {
t.modified = nil
return &Tree{t.root, t.size}
nt := t.CommitOnly()
if t.trackMutate {
t.Notify()
}
return nt
}
// CommitOnly is used to finalize the transaction and return a new tree, but
// does not issue any notifications until Notify is called.
func (t *Txn) CommitOnly() *Tree {
nt := &Tree{t.root, t.size}
t.writable = nil
return nt
}
// slowNotify does a complete comparison of the before and after trees in order
// to trigger notifications. This doesn't require any additional state but it
// is very expensive to compute.
func (t *Txn) slowNotify() {
snapIter := t.snap.rawIterator()
rootIter := t.root.rawIterator()
for snapIter.Front() != nil || rootIter.Front() != nil {
// If we've exhausted the nodes in the old snapshot, we know
// there's nothing remaining to notify.
if snapIter.Front() == nil {
return
}
snapElem := snapIter.Front()
// If we've exhausted the nodes in the new root, we know we need
// to invalidate everything that remains in the old snapshot. We
// know from the loop condition there's something in the old
// snapshot.
if rootIter.Front() == nil {
close(snapElem.mutateCh)
if snapElem.isLeaf() {
close(snapElem.leaf.mutateCh)
}
snapIter.Next()
continue
}
// Do one string compare so we can check the various conditions
// below without repeating the compare.
cmp := strings.Compare(snapIter.Path(), rootIter.Path())
// If the snapshot is behind the root, then we must have deleted
// this node during the transaction.
if cmp < 0 {
close(snapElem.mutateCh)
if snapElem.isLeaf() {
close(snapElem.leaf.mutateCh)
}
snapIter.Next()
continue
}
// If the snapshot is ahead of the root, then we must have added
// this node during the transaction.
if cmp > 0 {
rootIter.Next()
continue
}
// If we have the same path, then we need to see if we mutated a
// node and possibly the leaf.
rootElem := rootIter.Front()
if snapElem != rootElem {
close(snapElem.mutateCh)
if snapElem.leaf != nil && (snapElem.leaf != rootElem.leaf) {
close(snapElem.leaf.mutateCh)
}
}
snapIter.Next()
rootIter.Next()
}
}
// Notify is used along with TrackMutate to trigger notifications. This must
// only be done once a transaction is committed via CommitOnly, and it is called
// automatically by Commit.
func (t *Txn) Notify() {
if !t.trackMutate {
return
}
// If we've overflowed the tracking state we can't use it in any way and
// need to do a full tree compare.
if t.trackOverflow {
t.slowNotify()
} else {
for ch := range t.trackChannels {
close(ch)
}
}
// Clean up the tracking state so that a re-notify is safe (will trigger
// the else clause above which will be a no-op).
t.trackChannels = nil
t.trackOverflow = false
}
// Insert is used to add or update a given key. The return provides
@ -296,6 +612,14 @@ func (t *Tree) Delete(k []byte) (*Tree, interface{}, bool) {
return txn.Commit(), old, ok
}
// DeletePrefix is used to delete all nodes starting with a given prefix. Returns the new tree,
// and a bool indicating if the prefix matched any nodes
func (t *Tree) DeletePrefix(k []byte) (*Tree, bool) {
txn := t.Txn()
ok := txn.DeletePrefix(k)
return txn.Commit(), ok
}
// Root returns the root node of the tree which can be used for richer
// query operations.
func (t *Tree) Root() *Node {

View File

@ -9,11 +9,13 @@ type Iterator struct {
stack []edges
}
// SeekPrefix is used to seek the iterator to a given prefix
func (i *Iterator) SeekPrefix(prefix []byte) {
// SeekPrefixWatch is used to seek the iterator to a given prefix
// and returns the watch channel of the finest granularity
func (i *Iterator) SeekPrefixWatch(prefix []byte) (watch <-chan struct{}) {
// Wipe the stack
i.stack = nil
n := i.node
watch = n.mutateCh
search := prefix
for {
// Check for key exhaution
@ -29,6 +31,9 @@ func (i *Iterator) SeekPrefix(prefix []byte) {
return
}
// Update to the finest granularity as the search makes progress
watch = n.mutateCh
// Consume the search prefix
if bytes.HasPrefix(search, n.prefix) {
search = search[len(n.prefix):]
@ -43,6 +48,11 @@ func (i *Iterator) SeekPrefix(prefix []byte) {
}
}
// SeekPrefix is used to seek the iterator to a given prefix
func (i *Iterator) SeekPrefix(prefix []byte) {
i.SeekPrefixWatch(prefix)
}
// Next returns the next node in order
func (i *Iterator) Next() ([]byte, interface{}, bool) {
// Initialize our stack if needed

View File

@ -12,8 +12,9 @@ type WalkFn func(k []byte, v interface{}) bool
// leafNode is used to represent a value
type leafNode struct {
key []byte
val interface{}
mutateCh chan struct{}
key []byte
val interface{}
}
// edge is used to represent an edge node
@ -24,6 +25,9 @@ type edge struct {
// Node is an immutable node in the radix tree
type Node struct {
// mutateCh is closed if this node is modified
mutateCh chan struct{}
// leaf is used to store possible leaf
leaf *leafNode
@ -87,31 +91,14 @@ func (n *Node) delEdge(label byte) {
}
}
func (n *Node) mergeChild() {
e := n.edges[0]
child := e.node
n.prefix = concat(n.prefix, child.prefix)
if child.leaf != nil {
n.leaf = new(leafNode)
*n.leaf = *child.leaf
} else {
n.leaf = nil
}
if len(child.edges) != 0 {
n.edges = make([]edge, len(child.edges))
copy(n.edges, child.edges)
} else {
n.edges = nil
}
}
func (n *Node) Get(k []byte) (interface{}, bool) {
func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) {
search := k
watch := n.mutateCh
for {
// Check for key exhaution
// Check for key exhaustion
if len(search) == 0 {
if n.isLeaf() {
return n.leaf.val, true
return n.leaf.mutateCh, n.leaf.val, true
}
break
}
@ -122,6 +109,9 @@ func (n *Node) Get(k []byte) (interface{}, bool) {
break
}
// Update to the finest granularity as the search makes progress
watch = n.mutateCh
// Consume the search prefix
if bytes.HasPrefix(search, n.prefix) {
search = search[len(n.prefix):]
@ -129,7 +119,12 @@ func (n *Node) Get(k []byte) (interface{}, bool) {
break
}
}
return nil, false
return watch, nil, false
}
func (n *Node) Get(k []byte) (interface{}, bool) {
_, val, ok := n.GetWatch(k)
return val, ok
}
// LongestPrefix is like Get, but instead of an
@ -204,6 +199,14 @@ func (n *Node) Iterator() *Iterator {
return &Iterator{node: n}
}
// rawIterator is used to return a raw iterator at the given node to walk the
// tree.
func (n *Node) rawIterator() *rawIterator {
iter := &rawIterator{node: n}
iter.Next()
return iter
}
// Walk is used to walk the tree
func (n *Node) Walk(fn WalkFn) {
recursiveWalk(n, fn)
@ -271,6 +274,66 @@ func (n *Node) WalkPath(path []byte, fn WalkFn) {
}
}
func (n *Node) Seek(prefix []byte) *Seeker {
search := prefix
p := &pos{n: n}
for {
// Check for key exhaution
if len(search) == 0 {
return &Seeker{p}
}
num := len(n.edges)
idx := sort.Search(num, func(i int) bool {
return n.edges[i].label >= search[0]
})
p.current = idx
if idx < len(n.edges) {
n = n.edges[idx].node
if bytes.HasPrefix(search, n.prefix) && len(n.edges) > 0 {
search = search[len(n.prefix):]
p.current++
p = &pos{n: n, prev: p}
continue
}
}
p.current++
return &Seeker{p}
}
}
type Seeker struct {
*pos
}
type pos struct {
n *Node
current int
prev *pos
isLeaf bool
}
func (s *Seeker) Next() (k []byte, v interface{}, ok bool) {
if s.current >= len(s.n.edges) {
if s.prev == nil {
return nil, nil, false
}
s.pos = s.prev
return s.Next()
}
edge := s.n.edges[s.current]
s.current++
if edge.node.leaf != nil && !s.isLeaf {
s.isLeaf = true
s.current--
return edge.node.leaf.key, edge.node.leaf.val, true
}
s.isLeaf = false
s.pos = &pos{n: edge.node, prev: s.pos}
return s.Next()
}
// recursiveWalk is used to do a pre-order walk of a node
// recursively. Returns true if the walk should be aborted
func recursiveWalk(n *Node, fn WalkFn) bool {

View File

@ -0,0 +1,78 @@
package iradix
// rawIterator visits each of the nodes in the tree, even the ones that are not
// leaves. It keeps track of the effective path (what a leaf at a given node
// would be called), which is useful for comparing trees.
type rawIterator struct {
// node is the starting node in the tree for the iterator.
node *Node
// stack keeps track of edges in the frontier.
stack []rawStackEntry
// pos is the current position of the iterator.
pos *Node
// path is the effective path of the current iterator position,
// regardless of whether the current node is a leaf.
path string
}
// rawStackEntry is used to keep track of the cumulative common path as well as
// its associated edges in the frontier.
type rawStackEntry struct {
path string
edges edges
}
// Front returns the current node that has been iterated to.
func (i *rawIterator) Front() *Node {
return i.pos
}
// Path returns the effective path of the current node, even if it's not actually
// a leaf.
func (i *rawIterator) Path() string {
return i.path
}
// Next advances the iterator to the next node.
func (i *rawIterator) Next() {
// Initialize our stack if needed.
if i.stack == nil && i.node != nil {
i.stack = []rawStackEntry{
rawStackEntry{
edges: edges{
edge{node: i.node},
},
},
}
}
for len(i.stack) > 0 {
// Inspect the last element of the stack.
n := len(i.stack)
last := i.stack[n-1]
elem := last.edges[0].node
// Update the stack.
if len(last.edges) > 1 {
i.stack[n-1].edges = last.edges[1:]
} else {
i.stack = i.stack[:n-1]
}
// Push the edges onto the frontier.
if len(elem.edges) > 0 {
path := last.path + string(elem.prefix)
i.stack = append(i.stack, rawStackEntry{path, elem.edges})
}
i.pos = elem
i.path = last.path + string(elem.prefix)
return
}
i.pos = nil
i.path = ""
}