From 8738929b8c39facdb354595a6bb7c6825f6688f4 Mon Sep 17 00:00:00 2001 From: Tonis Tiigi Date: Mon, 31 Jul 2017 15:06:59 -0700 Subject: [PATCH] solver: implement content based cache support Signed-off-by: Tonis Tiigi --- cache/contenthash/checksum.go | 101 ++-- cache/contenthash/checksum_test.go | 2 +- cache/instructioncache/cache.go | 52 ++- cache/metadata.go | 1 - cache/metadata/metadata.go | 4 + cache/refs.go | 1 + session/filesync/diffcopy.go | 5 +- session/filesync/filesync.go | 1 + solver/build.go | 35 +- solver/exec.go | 95 +++- solver/load.go | 2 +- solver/solver.go | 431 +++++++++++------ solver/source.go | 16 +- solver/state.go | 20 +- solver/vertex.go | 6 +- source/local/local.go | 25 +- vendor.conf | 6 +- .../hashicorp/go-immutable-radix/iradix.go | 432 +++++++++++++++--- .../hashicorp/go-immutable-radix/iter.go | 14 +- .../hashicorp/go-immutable-radix/node.go | 111 ++++- .../hashicorp/go-immutable-radix/raw_iter.go | 78 ++++ 21 files changed, 1119 insertions(+), 319 deletions(-) create mode 100644 vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go diff --git a/cache/contenthash/checksum.go b/cache/contenthash/checksum.go index ce21b7fa..cde86077 100644 --- a/cache/contenthash/checksum.go +++ b/cache/contenthash/checksum.go @@ -47,12 +47,12 @@ func Checksum(ctx context.Context, ref cache.ImmutableRef, path string) (digest. return getDefaultManager().Checksum(ctx, ref, path) } -func GetCacheContext(ctx context.Context, ref cache.ImmutableRef) (CacheContext, error) { - return getDefaultManager().GetCacheContext(ctx, ref) +func GetCacheContext(ctx context.Context, md *metadata.StorageItem) (CacheContext, error) { + return getDefaultManager().GetCacheContext(ctx, md) } -func SetCacheContext(ctx context.Context, ref cache.ImmutableRef, cc CacheContext) error { - return getDefaultManager().SetCacheContext(ctx, ref, cc) +func SetCacheContext(ctx context.Context, md *metadata.StorageItem, cc CacheContext) error { + return getDefaultManager().SetCacheContext(ctx, md, cc) } type CacheContext interface { @@ -67,45 +67,57 @@ type Hashed interface { type cacheManager struct { locker *locker.Locker lru *simplelru.LRU + lruMu sync.Mutex } func (cm *cacheManager) Checksum(ctx context.Context, ref cache.ImmutableRef, p string) (digest.Digest, error) { - cc, err := cm.GetCacheContext(ctx, ref) + cc, err := cm.GetCacheContext(ctx, ensureOriginMetadata(ref.Metadata())) if err != nil { return "", nil } return cc.Checksum(ctx, ref, p) } -func (cm *cacheManager) GetCacheContext(ctx context.Context, ref cache.ImmutableRef) (CacheContext, error) { - cm.locker.Lock(ref.ID()) - v, ok := cm.lru.Get(ref.ID()) +func (cm *cacheManager) GetCacheContext(ctx context.Context, md *metadata.StorageItem) (CacheContext, error) { + cm.locker.Lock(md.ID()) + cm.lruMu.Lock() + v, ok := cm.lru.Get(md.ID()) + cm.lruMu.Unlock() if ok { - cm.locker.Unlock(ref.ID()) + cm.locker.Unlock(md.ID()) return v.(*cacheContext), nil } - cc, err := newCacheContext(ref.Metadata()) + cc, err := newCacheContext(md) if err != nil { - cm.locker.Unlock(ref.ID()) + cm.locker.Unlock(md.ID()) return nil, err } - cm.lru.Add(ref.ID(), cc) - cm.locker.Unlock(ref.ID()) + cm.lruMu.Lock() + cm.lru.Add(md.ID(), cc) + cm.lruMu.Unlock() + cm.locker.Unlock(md.ID()) return cc, nil } -func (cm *cacheManager) SetCacheContext(ctx context.Context, ref cache.ImmutableRef, cci CacheContext) error { +func (cm *cacheManager) SetCacheContext(ctx context.Context, md *metadata.StorageItem, cci CacheContext) error { cc, ok := cci.(*cacheContext) if !ok { return errors.Errorf("invalid cachecontext: %T", cc) } - if ref.ID() != cc.md.ID() { - return errors.New("saving cachecontext under different ID not supported") + if md.ID() != cc.md.ID() { + cc = &cacheContext{ + md: md, + tree: cci.(*cacheContext).tree, + dirtyMap: map[string]struct{}{}, + } + } else { + if err := cc.save(); err != nil { + return err + } } - if err := cc.save(); err != nil { - return err - } - cm.lru.Add(ref.ID(), cc) + cm.lruMu.Lock() + cm.lru.Add(md.ID(), cc) + cm.lruMu.Unlock() return nil } @@ -193,7 +205,9 @@ func (cc *cacheContext) save() error { cc.mu.Lock() defer cc.mu.Unlock() - cc.dirty = true + if cc.txn != nil { + cc.commitActiveTransaction() + } var l CacheRecords node := cc.tree.Root() @@ -231,10 +245,23 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil } cc.mu.Lock() + defer cc.mu.Unlock() if cc.txn == nil { cc.txn = cc.tree.Txn() cc.node = cc.tree.Root() + + // root is not called by HandleChange. need to fake it + if _, ok := cc.node.Get([]byte("/")); !ok { + cc.txn.Insert([]byte("/"), &CacheRecord{ + Type: CacheRecordTypeDirHeader, + Digest: digest.FromBytes(nil), + }) + cc.txn.Insert([]byte(""), &CacheRecord{ + Type: CacheRecordTypeDir, + }) + } } + if kind == fsutil.ChangeKindDelete { v, ok := cc.txn.Delete(k) if ok { @@ -245,7 +272,6 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil d = "" } cc.dirtyMap[d] = struct{}{} - cc.mu.Unlock() return } @@ -256,7 +282,6 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil h, ok := fi.(Hashed) if !ok { - cc.mu.Unlock() return errors.Errorf("invalid fileinfo: %s", p) } @@ -287,7 +312,6 @@ func (cc *cacheContext) HandleChange(kind fsutil.ChangeKind, p string, fi os.Fil d = "" } cc.dirtyMap[d] = struct{}{} - cc.mu.Unlock() return nil } @@ -405,12 +429,12 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir switch cr.Type { case CacheRecordTypeDir: h := sha256.New() - iter := root.Iterator() next := append(k, []byte("/")...) - iter.SeekPrefix(next) + iter := root.Seek(next) + subk := next + ok := true for { - subk, _, ok := iter.Next() - if !ok || bytes.Compare(next, subk) > 0 { + if !ok || !bytes.HasPrefix(subk, next) { break } h.Write(bytes.TrimPrefix(subk, k)) @@ -422,9 +446,10 @@ func (cc *cacheContext) checksum(ctx context.Context, root *iradix.Node, txn *ir h.Write([]byte(subcr.Digest)) if subcr.Type == CacheRecordTypeDir { // skip subfiles - next = append(k, []byte("/\xff")...) - iter.SeekPrefix(next) + next := append(subk, []byte("/\xff")...) + iter = root.Seek(next) } + subk, _, ok = iter.Next() } dgst = digest.NewDigest(digest.SHA256, h) default: @@ -565,3 +590,19 @@ func addParentToMap(d string, m map[string]struct{}) { m[d] = struct{}{} addParentToMap(d, m) } + +func ensureOriginMetadata(md *metadata.StorageItem) *metadata.StorageItem { + v := md.Get("cache.equalMutable") // TODO: const + if v == nil { + return md + } + var mutable string + if err := v.Unmarshal(&mutable); err != nil { + return md + } + si, ok := md.Storage().Get(mutable) + if ok { + return &si + } + return md +} diff --git a/cache/contenthash/checksum_test.go b/cache/contenthash/checksum_test.go index 4ef840a0..47abf522 100644 --- a/cache/contenthash/checksum_test.go +++ b/cache/contenthash/checksum_test.go @@ -92,7 +92,7 @@ func TestChecksumBasicFile(t *testing.T) { dgst, err = cc.Checksum(context.TODO(), ref, "/") assert.NoError(t, err) - assert.Equal(t, digest.Digest("sha256:0d87c8c2a606f961483cd4c5dc0350a4136a299b4066eea4a969d6ed756614cd"), dgst) + assert.Equal(t, digest.Digest("sha256:7378af5287e8b417b6cbc63154d300e130983bfc645e35e86fdadf6f5060468a"), dgst) dgst, err = cc.Checksum(context.TODO(), ref, "d0") assert.NoError(t, err) diff --git a/cache/instructioncache/cache.go b/cache/instructioncache/cache.go index 22eb05ed..33912b7d 100644 --- a/cache/instructioncache/cache.go +++ b/cache/instructioncache/cache.go @@ -1,22 +1,26 @@ package instructioncache import ( + "strings" + "github.com/boltdb/bolt" "github.com/moby/buildkit/cache" "github.com/moby/buildkit/cache/metadata" + digest "github.com/opencontainers/go-digest" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/net/context" ) const cacheKey = "buildkit.instructioncache" +const contentCacheKey = "buildkit.instructioncache.content" type LocalStore struct { MetadataStore *metadata.Store Cache cache.Accessor } -func (ls *LocalStore) Set(key string, value interface{}) error { +func (ls *LocalStore) Set(key digest.Digest, value interface{}) error { ref, ok := value.(cache.ImmutableRef) if !ok { return errors.Errorf("invalid ref") @@ -25,21 +29,21 @@ func (ls *LocalStore) Set(key string, value interface{}) error { if err != nil { return err } - v.Index = index(key) + v.Index = index(key.String()) si, _ := ls.MetadataStore.Get(ref.ID()) return si.Update(func(b *bolt.Bucket) error { - return si.SetValue(b, index(key), v) + return si.SetValue(b, v.Index, v) }) } -func (ls *LocalStore) Lookup(ctx context.Context, key string) (interface{}, error) { - snaps, err := ls.MetadataStore.Search(index(key)) +func (ls *LocalStore) Lookup(ctx context.Context, key digest.Digest) (interface{}, error) { + snaps, err := ls.MetadataStore.Search(index(key.String())) if err != nil { return nil, err } for _, s := range snaps { - v := s.Get(index(key)) + v := s.Get(index(key.String())) if v != nil { var id string if err = v.Unmarshal(&id); err != nil { @@ -56,6 +60,42 @@ func (ls *LocalStore) Lookup(ctx context.Context, key string) (interface{}, erro return nil, nil } +func (ls *LocalStore) SetContentMapping(key digest.Digest, value interface{}) error { + ref, ok := value.(cache.ImmutableRef) + if !ok { + return errors.Errorf("invalid ref") + } + v, err := metadata.NewValue(ref.ID()) + if err != nil { + return err + } + v.Index = contentIndex(key.String()) + si, _ := ls.MetadataStore.Get(ref.ID()) + return si.Update(func(b *bolt.Bucket) error { + return si.SetValue(b, v.Index, v) + }) +} + +func (ls *LocalStore) GetContentMapping(key digest.Digest) ([]digest.Digest, error) { + snaps, err := ls.MetadataStore.Search(contentIndex(key.String())) + if err != nil { + return nil, err + } + var out []digest.Digest + for _, s := range snaps { + for _, k := range s.Keys() { + if strings.HasPrefix(k, index("")) { + out = append(out, digest.Digest(strings.TrimPrefix(k, index("")))) // TODO: type + } + } + } + return out, nil +} + func index(k string) string { return cacheKey + "::" + k } + +func contentIndex(k string) string { + return contentCacheKey + "::" + k +} diff --git a/cache/metadata.go b/cache/metadata.go index 0a4083a1..f837fede 100644 --- a/cache/metadata.go +++ b/cache/metadata.go @@ -18,7 +18,6 @@ import ( const sizeUnknown int64 = -1 const keySize = "snapshot.size" const keyEqualMutable = "cache.equalMutable" -const keyEqualImmutable = "cache.equalImmutable" const keyCachePolicy = "cache.cachePolicy" const keyDescription = "cache.description" const keyCreatedAt = "cache.createdAt" diff --git a/cache/metadata/metadata.go b/cache/metadata/metadata.go index 5160aa8c..5af6da13 100644 --- a/cache/metadata/metadata.go +++ b/cache/metadata/metadata.go @@ -205,6 +205,10 @@ func newStorageItem(id string, b *bolt.Bucket, s *Store) (StorageItem, error) { return si, nil } +func (s *StorageItem) Storage() *Store { // TODO: used in local source. how to remove this? + return s.storage +} + func (s *StorageItem) ID() string { return s.id } diff --git a/cache/refs.go b/cache/refs.go index 5e358d70..58e865fa 100644 --- a/cache/refs.go +++ b/cache/refs.go @@ -28,6 +28,7 @@ type MutableRef interface { Commit(context.Context) (ImmutableRef, error) Release(context.Context) error Size(ctx context.Context) (int64, error) + Metadata() *metadata.StorageItem } type Mountable interface { diff --git a/session/filesync/diffcopy.go b/session/filesync/diffcopy.go index 88edfb08..f161154e 100644 --- a/session/filesync/diffcopy.go +++ b/session/filesync/diffcopy.go @@ -21,10 +21,11 @@ func recvDiffCopy(ds grpc.Stream, dest string, cu CacheUpdater, progress progres logrus.Debugf("diffcopy took: %v", time.Since(st)) }() var cf fsutil.ChangeFunc + var ch fsutil.ContentHasher if cu != nil { cu.MarkSupported(true) cf = cu.HandleChange + ch = cu.ContentHasher() } - _ = cf - return fsutil.Receive(ds.Context(), ds, dest, nil, nil, progress) + return fsutil.Receive(ds.Context(), ds, dest, cf, ch, progress) } diff --git a/session/filesync/filesync.go b/session/filesync/filesync.go index 7f02dd15..f7787868 100644 --- a/session/filesync/filesync.go +++ b/session/filesync/filesync.go @@ -147,6 +147,7 @@ type FSSendRequestOpt struct { type CacheUpdater interface { MarkSupported(bool) HandleChange(fsutil.ChangeKind, string, os.FileInfo, error) error + ContentHasher() fsutil.ContentHasher } // FSSync initializes a transfer of files diff --git a/solver/build.go b/solver/build.go index 9459345b..becefe42 100644 --- a/solver/build.go +++ b/solver/build.go @@ -15,6 +15,8 @@ import ( "golang.org/x/net/context" ) +const buildCacheType = "buildkit.build.v0" + type buildOp struct { op *pb.BuildOp s *Solver @@ -29,18 +31,18 @@ func newBuildOp(v Vertex, op *pb.Op_Build, s *Solver) (Op, error) { }, nil } -func (b *buildOp) CacheKey(ctx context.Context, inputs []string) (string, int, error) { +func (b *buildOp) CacheKey(ctx context.Context) (digest.Digest, error) { dt, err := json.Marshal(struct { - Inputs []string - Exec *pb.BuildOp + Type string + Exec *pb.BuildOp }{ - Inputs: inputs, - Exec: b.op, + Type: buildCacheType, + Exec: b.op, }) if err != nil { - return "", 0, err + return "", err } - return digest.FromBytes(dt).String(), 1, nil // TODO: other builders should support many outputs + return digest.FromBytes(dt), nil } func (b *buildOp) Run(ctx context.Context, inputs []Reference) (outputs []Reference, retErr error) { @@ -125,17 +127,14 @@ func (b *buildOp) Run(ctx context.Context, inputs []Reference) (outputs []Refere pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("parentVertex", b.v.Digest())) defer pw.Close() - refs, err := b.s.getRefs(ctx, vv) - - // filer out only the required ref - var out Reference - for i, r := range refs { - if i == index { - out = r - } else { - go r.Release(context.TODO()) - } + newref, err := b.s.getRef(ctx, vv, index) + if err != nil { + return nil, err } - return []Reference{out}, err + return []Reference{newref}, err +} + +func (b *buildOp) ContentKeys(context.Context, [][]digest.Digest, []Reference) ([]digest.Digest, error) { + return nil, nil } diff --git a/solver/exec.go b/solver/exec.go index da899aae..8302ec60 100644 --- a/solver/exec.go +++ b/solver/exec.go @@ -7,14 +7,18 @@ import ( "strings" "github.com/moby/buildkit/cache" + "github.com/moby/buildkit/cache/contenthash" "github.com/moby/buildkit/solver/pb" "github.com/moby/buildkit/util/progress/logs" "github.com/moby/buildkit/worker" digest "github.com/opencontainers/go-digest" "github.com/pkg/errors" "golang.org/x/net/context" + "golang.org/x/sync/errgroup" ) +const execCacheType = "buildkit.exec.v0" + type execOp struct { op *pb.ExecOp cm cache.Manager @@ -29,25 +33,19 @@ func newExecOp(_ Vertex, op *pb.Op_Exec, cm cache.Manager, w worker.Worker) (Op, }, nil } -func (e *execOp) CacheKey(ctx context.Context, inputs []string) (string, int, error) { +func (e *execOp) CacheKey(ctx context.Context) (digest.Digest, error) { dt, err := json.Marshal(struct { - Inputs []string - Exec *pb.ExecOp + Type string + Exec *pb.ExecOp }{ - Inputs: inputs, - Exec: e.op, + Type: execCacheType, + Exec: e.op, }) if err != nil { - return "", 0, err - } - numRefs := 0 - for _, m := range e.op.Mounts { - if m.Output != pb.SkipOutput { - numRefs++ - } + return "", err } - return digest.FromBytes(dt).String(), numRefs, nil + return digest.FromBytes(dt), nil } func (e *execOp) Run(ctx context.Context, inputs []Reference) ([]Reference, error) { @@ -130,3 +128,74 @@ func (e *execOp) Run(ctx context.Context, inputs []Reference) ([]Reference, erro } return refs, nil } + +func (e *execOp) ContentKeys(ctx context.Context, inputs [][]digest.Digest, refs []Reference) ([]digest.Digest, error) { + if len(refs) == 0 { + return nil, nil + } + // contentKey for exec uses content based checksum for mounts and definition + // based checksum for root + + rootIndex := -1 + skip := true + srcs := make([]string, len(refs)) + for _, m := range e.op.Mounts { + if m.Input != pb.Empty { + if m.Dest != pb.RootMount { + srcs[int(m.Input)] = "/" // TODO: selector + skip = false + } else { + rootIndex = int(m.Input) + } + } + } + if skip { + return nil, nil + } + + dgsts := make([]digest.Digest, len(refs)) + eg, ctx := errgroup.WithContext(ctx) + for i, ref := range refs { + if srcs[i] == "" { + continue + } + func(i int, ref Reference) { + eg.Go(func() error { + ref, ok := toImmutableRef(ref) + if !ok { + return errors.Errorf("invalid reference") + } + dgst, err := contenthash.Checksum(ctx, ref, srcs[i]) + if err != nil { + return err + } + dgsts[i] = dgst + return nil + }) + }(i, ref) + } + if err := eg.Wait(); err != nil { + return nil, err + } + + var out []digest.Digest + for _, cacheKeys := range inputs { + dt, err := json.Marshal(struct { + Type string + Sources []digest.Digest + Root digest.Digest + Exec *pb.ExecOp + }{ + Type: execCacheType, + Sources: dgsts, + Root: cacheKeys[rootIndex], + Exec: e.op, + }) + if err != nil { + return nil, err + } + out = append(out, digest.FromBytes(dt)) + } + + return out, nil +} diff --git a/solver/load.go b/solver/load.go index 77e26a21..b6929bb5 100644 --- a/solver/load.go +++ b/solver/load.go @@ -70,7 +70,7 @@ func loadLLBVertexRecursive(dgst digest.Digest, op *pb.Op, all map[digest.Digest if err != nil { return nil, err } - vtx.inputs = append(vtx.inputs, &input{index: int(in.Index), vertex: sub}) + vtx.inputs = append(vtx.inputs, &input{index: Index(in.Index), vertex: sub}) } vtx.initClientVertex() cache[dgst] = vtx diff --git a/solver/solver.go b/solver/solver.go index 8566bcc5..27dfae58 100644 --- a/solver/solver.go +++ b/solver/solver.go @@ -1,7 +1,9 @@ package solver import ( + "encoding/json" "fmt" + "sync" "github.com/moby/buildkit/cache" "github.com/moby/buildkit/client" @@ -10,6 +12,7 @@ import ( "github.com/moby/buildkit/source" "github.com/moby/buildkit/util/progress" "github.com/moby/buildkit/worker" + digest "github.com/opencontainers/go-digest" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/net/context" @@ -50,13 +53,16 @@ type Reference interface { // Op is an implementation for running a vertex type Op interface { - CacheKey(context.Context, []string) (string, int, error) + CacheKey(context.Context) (digest.Digest, error) + ContentKeys(context.Context, [][]digest.Digest, []Reference) ([]digest.Digest, error) Run(ctx context.Context, inputs []Reference) (outputs []Reference, err error) } type InstructionCache interface { - Lookup(ctx context.Context, key string) (interface{}, error) // TODO: regular ref - Set(key string, ref interface{}) error + Lookup(ctx context.Context, key digest.Digest) (interface{}, error) // TODO: regular ref + Set(key digest.Digest, ref interface{}) error + SetContentMapping(key digest.Digest, value interface{}) error + GetContentMapping(dgst digest.Digest) ([]digest.Digest, error) } type Solver struct { @@ -101,34 +107,25 @@ func (s *Solver) Solve(ctx context.Context, id string, v Vertex, exp exporter.Ex return err } - refs, err := s.getRefs(ctx, solveVertex) + ref, err := s.getRef(ctx, solveVertex, index) s.activeState.cancel(j) if err != nil { return err } defer func() { - for _, r := range refs { - r.Release(context.TODO()) - } + ref.Release(context.TODO()) }() - for _, ref := range refs { - immutable, ok := toImmutableRef(ref) - if !ok { - return errors.Errorf("invalid reference for exporting: %T", ref) - } - if err := immutable.Finalize(ctx); err != nil { - return err - } + immutable, ok := toImmutableRef(ref) + if !ok { + return errors.Errorf("invalid reference for exporting: %T", ref) + } + if err := immutable.Finalize(ctx); err != nil { + return err } if exp != nil { - r := refs[int(index)] - immutable, ok := toImmutableRef(r) - if !ok { - return errors.Errorf("invalid reference for exporting: %T", r) - } vv.notifyStarted(ctx) pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", vv.Digest())) defer pw.Close() @@ -150,32 +147,13 @@ func (s *Solver) Status(ctx context.Context, id string, statusChan chan *client. return j.pipe(ctx, statusChan) } -func (s *Solver) getCacheKey(ctx context.Context, g *vertex) (cacheKey string, numRefs int, retErr error) { +// getCacheKey return a cache key for a single output of a vertex +func (s *Solver) getCacheKey(ctx context.Context, g *vertex, inputs []digest.Digest, index Index) (dgst digest.Digest, retErr error) { state, err := s.activeState.vertexState(ctx, g.digest, func() (Op, error) { return s.resolve(g) }) if err != nil { - return "", 0, err - } - - inputs := make([]string, len(g.inputs)) - if len(g.inputs) > 0 { - eg, ctx := errgroup.WithContext(ctx) - for i, in := range g.inputs { - func(i int, in *vertex, index int) { - eg.Go(func() error { - k, _, err := s.getCacheKey(ctx, in) - if err != nil { - return err - } - inputs[i] = fmt.Sprintf("%s.%d", k, index) - return nil - }) - }(i, in.vertex, in.index) - } - if err := eg.Wait(); err != nil { - return "", 0, err - } + return "", err } pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", g.Digest())) @@ -188,124 +166,200 @@ func (s *Solver) getCacheKey(ctx context.Context, g *vertex) (cacheKey string, n }() } - return state.GetCacheKey(ctx, func(ctx context.Context, op Op) (string, int, error) { - return op.CacheKey(ctx, inputs) + dgst, err = state.GetCacheKey(ctx, func(ctx context.Context, op Op) (digest.Digest, error) { + return op.CacheKey(ctx) }) + + if err != nil { + return "", err + } + + dt, err := json.Marshal(struct { + Index Index + Inputs []digest.Digest + Digest digest.Digest + }{ + Index: index, + Inputs: inputs, + Digest: dgst, + }) + if err != nil { + return "", err + } + return digest.FromBytes(dt), nil } -func (s *Solver) getRefs(ctx context.Context, g *vertex) (retRef []Reference, retErr error) { +// walkVertex walks all possible cache keys and a evaluated reference for a +// single output of a vertex. +func (s *Solver) walkVertex(ctx context.Context, g *vertex, index Index, fn func(digest.Digest, Reference) (bool, error)) (retErr error) { state, err := s.activeState.vertexState(ctx, g.digest, func() (Op, error) { return s.resolve(g) }) if err != nil { - return nil, err + return err } - var cacheKey string - if s.cache != nil { - var err error - var numRefs int - cacheKey, numRefs, err = s.getCacheKey(ctx, g) - if err != nil { - return nil, err - } - cacheRefs := make([]Reference, 0, numRefs) - // check if all current refs are already cached - for i := 0; i < numRefs; i++ { - ref, err := s.cache.Lookup(ctx, fmt.Sprintf("%s.%d", cacheKey, i)) - if err != nil { - return nil, err - } - if ref == nil { // didn't find ref, release all - for _, ref := range cacheRefs { - ref.Release(context.TODO()) - } - break - } - cacheRefs = append(cacheRefs, ref.(Reference)) - if len(cacheRefs) == numRefs { // last item - g.recursiveMarkCached(ctx) - return cacheRefs, nil - } - } - } + inputCacheKeysMu := sync.Mutex{} + inputCacheKeys := make([][]digest.Digest, len(g.inputs)) + walkerStopped := false - // refs contains all outputs for all input vertexes - refs := make([][]*sharedRef, len(g.inputs)) - if len(g.inputs) > 0 { - eg, ctx := errgroup.WithContext(ctx) - for i, in := range g.inputs { - func(i int, in *vertex, index int) { - eg.Go(func() error { - if s.cache != nil { - k, numRefs, err := s.getCacheKey(ctx, in) - if err != nil { - return err - } - ref, err := s.cache.Lookup(ctx, fmt.Sprintf("%s.%d", k, index)) - if err != nil { - return err - } - if ref != nil { - if ref, ok := toImmutableRef(ref.(Reference)); ok { - refs[i] = make([]*sharedRef, numRefs) - refs[i][index] = newSharedRef(ref) - in.recursiveMarkCached(ctx) - return nil - } - } - } + inputCtx, cancelInputCtx := context.WithCancel(ctx) + defer cancelInputCtx() - // execute input vertex - r, err := s.getRefs(ctx, in) - if err != nil { - return err - } - for _, r := range r { - refs[i] = append(refs[i], newSharedRef(r)) - } - if ref, ok := toImmutableRef(r[index].(Reference)); ok { - // make sure input that is required by next step does not get released in case build is cancelled - if err := cache.CachePolicyRetain(ref); err != nil { - return err - } - } - return nil - }) - }(i, in.vertex, in.index) - } - err := eg.Wait() - if err != nil { - for _, r := range refs { - for _, r := range r { - if r != nil { - go r.Release(context.TODO()) - } - } - } - return nil, err - } - } - - // determine the inputs that were needed - inputRefs := make([]Reference, 0, len(g.inputs)) - for i, inp := range g.inputs { - inputRefs = append(inputRefs, refs[i][inp.index].Clone()) - } + inputRefs := make([]Reference, len(g.inputs)) defer func() { for _, r := range inputRefs { - go r.Release(context.TODO()) - } - }() - - // release anything else - for _, r := range refs { - for _, r := range r { if r != nil { go r.Release(context.TODO()) } } + }() + + if len(g.inputs) > 0 { + eg, ctx := errgroup.WithContext(ctx) + for i, in := range g.inputs { + func(i int, in *input) { + eg.Go(func() error { + var inputRef Reference + defer func() { + if inputRef != nil { + go inputRef.Release(context.TODO()) + } + }() + err := s.walkVertex(inputCtx, in.vertex, in.index, func(k digest.Digest, ref Reference) (bool, error) { + if k == "" && ref == nil { + // indicator between cache key and reference + if inputRef != nil { + return true, nil + } + // TODO: might be good to block here if other inputs may + // cause cache hits to avoid duplicate work. + return false, nil + } + if ref != nil { + inputRef = ref + return true, nil + } + + inputCacheKeysMu.Lock() + defer inputCacheKeysMu.Unlock() + + if walkerStopped { + return walkerStopped, nil + } + + // try all known combinations together with new key + inputCacheKeysCopy := append([][]digest.Digest{}, inputCacheKeys...) + inputCacheKeysCopy[i] = []digest.Digest{k} + inputCacheKeys[i] = append(inputCacheKeys[i], k) + + for _, inputKeys := range combinations(inputCacheKeysCopy) { + cacheKey, err := s.getCacheKey(ctx, g, inputKeys, index) + if err != nil { + return false, err + } + stop, err := fn(cacheKey, nil) + if err != nil { + return false, err + } + if stop { + walkerStopped = true + cancelInputCtx() // parent matched, stop processing current node and its inputs + return true, nil + } + } + + // if no parent matched, try looking up current node from cache + if s.cache != nil && inputRef == nil { + lookupRef, err := s.cache.Lookup(ctx, k) + if err != nil { + return false, err + } + if lookupRef != nil { + inputRef = lookupRef.(Reference) + in.vertex.recursiveMarkCached(ctx) + return true, nil + } + } + return false, nil + }) + if inputRef != nil { + // make sure that the inputs for other steps don't get released on cancellation + if ref, ok := toImmutableRef(inputRef); ok { + if err := cache.CachePolicyRetain(ref); err != nil { + return err + } + if err := ref.Metadata().Commit(); err != nil { + return err + } + } + } + inputCacheKeysMu.Lock() + defer inputCacheKeysMu.Unlock() + if walkerStopped { + return nil + } + if err != nil { + return err + } + inputRefs[i] = inputRef + inputRef = nil + return nil + }) + }(i, in) + } + if err := eg.Wait(); err != nil && !walkerStopped { + return err + } + } else { + cacheKey, err := s.getCacheKey(ctx, g, nil, index) + if err != nil { + return err + } + stop, err := fn(cacheKey, nil) + if err != nil { + return err + } + walkerStopped = stop + } + + if walkerStopped { + return nil + } + + var contentKeys []digest.Digest + if s.cache != nil { + // try to determine content based key + contentKeys, err = state.op.ContentKeys(ctx, combinations(inputCacheKeys), inputRefs) + if err != nil { + return err + } + + for _, k := range contentKeys { + cks, err := s.cache.GetContentMapping(contentKeyWithIndex(k, index)) + if err != nil { + return err + } + for _, k := range cks { + stop, err := fn(k, nil) + if err != nil { + return err + } + if stop { + return nil + } + } + } + } + + // signal that no more cache keys are coming + stop, err := fn("", nil) + if err != nil { + return err + } + if stop { + return nil } pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", g.Digest())) @@ -316,18 +370,105 @@ func (s *Solver) getRefs(ctx context.Context, g *vertex) (retRef []Reference, re g.notifyCompleted(ctx, false, retErr) }() - return state.GetRefs(ctx, func(ctx context.Context, op Op) ([]Reference, error) { + ref, err := state.GetRefs(ctx, index, func(ctx context.Context, op Op) ([]Reference, error) { refs, err := op.Run(ctx, inputRefs) if err != nil { return nil, err } if s.cache != nil { + mainInputKeys := firstKeys(inputCacheKeys) for i, ref := range refs { - if err := s.cache.Set(fmt.Sprintf("%s.%d", cacheKey, i), originRef(ref)); err != nil { - logrus.Errorf("failed to save cache for %s: %v", cacheKey, err) + if ref != nil { + cacheKey, err := s.getCacheKey(ctx, g, mainInputKeys, Index(i)) + if err != nil { + return nil, err + } + r := originRef(ref) + if err := s.cache.Set(cacheKey, r); err != nil { + logrus.Errorf("failed to save cache for %s: %v", cacheKey, err) + } + + for _, k := range contentKeys { + if err := s.cache.SetContentMapping(contentKeyWithIndex(k, Index(i)), r); err != nil { + logrus.Errorf("failed to save content mapping: %v", err) + } + } } } } return refs, nil }) + if err != nil { + return err + } + + // return reference + _, err = fn("", ref) + if err != nil { + return err + } + + return nil +} + +func (s *Solver) getRef(ctx context.Context, g *vertex, index Index) (ref Reference, retErr error) { + logrus.Debugf("> getRef %s %v %s", g.digest, index, g.name) + defer logrus.Debugf("< getRef %s %v", g.digest, index) + + var returnRef Reference + err := s.walkVertex(ctx, g, index, func(ck digest.Digest, ref Reference) (bool, error) { + if ref != nil { + returnRef = ref + return true, nil + } + if ck == "" { + return false, nil + } + lookupRef, err := s.cache.Lookup(ctx, ck) + if err != nil { + return false, err + } + if lookupRef != nil { + g.recursiveMarkCached(ctx) + returnRef = lookupRef.(Reference) + return true, nil + } + return false, nil + }) + if err != nil { + return nil, err + } + + return returnRef, nil +} + +func firstKeys(inp [][]digest.Digest) []digest.Digest { + var out []digest.Digest + for _, v := range inp { + out = append(out, v[0]) + } + return out +} + +func combinations(inp [][]digest.Digest) [][]digest.Digest { + var out [][]digest.Digest + if len(inp) == 0 { + return inp + } + if len(inp) == 1 { + for _, v := range inp[0] { + out = append(out, []digest.Digest{v}) + } + return out + } + for _, v := range inp[0] { + for _, v2 := range combinations(inp[1:]) { + out = append(out, append([]digest.Digest{v}, v2...)) + } + } + return out +} + +func contentKeyWithIndex(dgst digest.Digest, index Index) digest.Digest { + return digest.FromBytes([]byte(fmt.Sprintf("%s.%d", dgst, index))) } diff --git a/solver/source.go b/solver/source.go index 6180a9c7..7d4c6a83 100644 --- a/solver/source.go +++ b/solver/source.go @@ -5,9 +5,12 @@ import ( "github.com/moby/buildkit/solver/pb" "github.com/moby/buildkit/source" + digest "github.com/opencontainers/go-digest" "golang.org/x/net/context" ) +const sourceCacheType = "buildkit.source.v0" + type sourceOp struct { mu sync.Mutex op *pb.Op_Source @@ -58,13 +61,16 @@ func (s *sourceOp) instance(ctx context.Context) (source.SourceInstance, error) return s.src, nil } -func (s *sourceOp) CacheKey(ctx context.Context, _ []string) (string, int, error) { +func (s *sourceOp) CacheKey(ctx context.Context) (digest.Digest, error) { src, err := s.instance(ctx) if err != nil { - return "", 0, err + return "", err } k, err := src.CacheKey(ctx) - return k, 1, err + if err != nil { + return "", err + } + return digest.FromBytes([]byte(sourceCacheType + ":" + k)), nil } func (s *sourceOp) Run(ctx context.Context, _ []Reference) ([]Reference, error) { @@ -78,3 +84,7 @@ func (s *sourceOp) Run(ctx context.Context, _ []Reference) ([]Reference, error) } return []Reference{ref}, nil } + +func (s *sourceOp) ContentKeys(context.Context, [][]digest.Digest, []Reference) ([]digest.Digest, error) { + return nil, nil +} diff --git a/solver/state.go b/solver/state.go index 5adae684..518e15ab 100644 --- a/solver/state.go +++ b/solver/state.go @@ -25,8 +25,7 @@ type state struct { key digest.Digest jobs map[*job]struct{} refs []*sharedRef - cacheKey string - numRefs int + cacheKey digest.Digest op Op progressCtx context.Context cacheCtx context.Context @@ -75,7 +74,7 @@ func (s *activeState) cancel(j *job) { } } -func (s *state) GetRefs(ctx context.Context, cb func(context.Context, Op) ([]Reference, error)) ([]Reference, error) { +func (s *state) GetRefs(ctx context.Context, index Index, cb func(context.Context, Op) ([]Reference, error)) (Reference, error) { _, err := s.Do(ctx, s.key.String(), func(doctx context.Context) (interface{}, error) { if s.refs != nil { if err := writeProgressSnapshot(s.progressCtx, ctx); err != nil { @@ -98,14 +97,10 @@ func (s *state) GetRefs(ctx context.Context, cb func(context.Context, Op) ([]Ref if err != nil { return nil, err } - refs := make([]Reference, 0, len(s.refs)) - for _, r := range s.refs { - refs = append(refs, r.Clone()) - } - return refs, nil + return s.refs[int(index)].Clone(), nil } -func (s *state) GetCacheKey(ctx context.Context, cb func(context.Context, Op) (string, int, error)) (string, int, error) { +func (s *state) GetCacheKey(ctx context.Context, cb func(context.Context, Op) (digest.Digest, error)) (digest.Digest, error) { _, err := s.Do(ctx, "cache:"+s.key.String(), func(doctx context.Context) (interface{}, error) { if s.cacheKey != "" { if err := writeProgressSnapshot(s.cacheCtx, ctx); err != nil { @@ -113,19 +108,18 @@ func (s *state) GetCacheKey(ctx context.Context, cb func(context.Context, Op) (s } return nil, nil } - cacheKey, numRefs, err := cb(doctx, s.op) + cacheKey, err := cb(doctx, s.op) if err != nil { return nil, err } s.cacheKey = cacheKey - s.numRefs = numRefs s.cacheCtx = doctx return nil, nil }) if err != nil { - return "", 0, err + return "", err } - return s.cacheKey, s.numRefs, nil + return s.cacheKey, nil } func writeProgressSnapshot(srcCtx, destCtx context.Context) error { diff --git a/solver/vertex.go b/solver/vertex.go index a6cb9928..90655a63 100644 --- a/solver/vertex.go +++ b/solver/vertex.go @@ -22,14 +22,16 @@ type Vertex interface { Name() string // change this to general metadata } +type Index int + // Input is an pointer to a single reference from a vertex by an index. type Input struct { - Index int + Index Index Vertex Vertex } type input struct { - index int + index Index vertex *vertex } diff --git a/source/local/local.go b/source/local/local.go index 426c5dd5..bc2363e2 100644 --- a/source/local/local.go +++ b/source/local/local.go @@ -6,6 +6,7 @@ import ( "github.com/boltdb/bolt" "github.com/moby/buildkit/cache" + "github.com/moby/buildkit/cache/contenthash" "github.com/moby/buildkit/cache/metadata" "github.com/moby/buildkit/session" "github.com/moby/buildkit/session/filesync" @@ -14,6 +15,7 @@ import ( "github.com/moby/buildkit/util/progress" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/tonistiigi/fsutil" "golang.org/x/net/context" "golang.org/x/time/rate" ) @@ -139,12 +141,17 @@ func (ls *localSourceHandler) Snapshot(ctx context.Context) (out cache.Immutable } }() + cc, err := contenthash.GetCacheContext(ctx, mutable.Metadata()) + if err != nil { + return nil, err + } + opt := filesync.FSSendRequestOpt{ Name: ls.src.Name, IncludePatterns: nil, OverrideExcludes: false, DestDir: dest, - CacheUpdater: nil, + CacheUpdater: &cacheUpdater{cc}, ProgressCb: newProgressHandler(ctx, "transferring "+ls.src.Name+":"), } @@ -157,6 +164,10 @@ func (ls *localSourceHandler) Snapshot(ctx context.Context) (out cache.Immutable } lm = nil + if err := contenthash.SetCacheContext(ctx, mutable.Metadata(), cc); err != nil { + return nil, err + } + // skip storing snapshot by the shared key if it already exists skipStoreSharedKey := false si, _ := ls.md.Get(mutable.ID()) @@ -185,6 +196,7 @@ func (ls *localSourceHandler) Snapshot(ctx context.Context) (out cache.Immutable if err != nil { return nil, err } + mutable = nil // avoid deferred cleanup return snap, nil @@ -213,3 +225,14 @@ func newProgressHandler(ctx context.Context, id string) func(int, bool) { } } } + +type cacheUpdater struct { + contenthash.CacheContext +} + +func (cu *cacheUpdater) MarkSupported(bool) { +} + +func (cu *cacheUpdater) ContentHasher() fsutil.ContentHasher { + return contenthash.NewFromStat +} diff --git a/vendor.conf b/vendor.conf index 692e2dd5..a0aaf210 100644 --- a/vendor.conf +++ b/vendor.conf @@ -7,7 +7,7 @@ github.com/pmezard/go-difflib v1.0.0 golang.org/x/sys 739734461d1c916b6c72a63d7efda2b27edb369f github.com/containerd/containerd 036232856fb8f088a844b22f3330bcddb5d44c0a -golang.org/x/sync 450f422ab23cf9881c94e2db30cac0eb1b7cf80c +golang.org/x/sync f52d1811a62927559de87708c8913c1650ce4f26 github.com/sirupsen/logrus v1.0.0 google.golang.org/grpc v1.3.0 github.com/opencontainers/go-digest 21dfd564fd89c944783d00d069f33e3e7123c448 @@ -39,5 +39,5 @@ github.com/pkg/profile 5b67d428864e92711fcbd2f8629456121a56d91f github.com/tonistiigi/fsutil 195d62bee906e45aa700b8ebeb3417f7b126bb23 github.com/stevvooe/continuity 86cec1535a968310e7532819f699ff2830ed7463 github.com/dmcgowan/go-tar 2e2c51242e8993c50445dab7c03c8e7febddd0cf -github.com/hashicorp/go-immutable-radix 8e8ed81f8f0bf1bdd829593fdd5c29922c1ea990 -github.com/hashicorp/golang-lru a0d98a5f288019575c6d1f4bb1573fef2d1fcdc4 \ No newline at end of file +github.com/hashicorp/go-immutable-radix 826af9ccf0feeee615d546d69b11f8e98da8c8f1 git://github.com/tonistiigi/go-immutable-radix.git +github.com/hashicorp/golang-lru a0d98a5f288019575c6d1f4bb1573fef2d1fcdc4 diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go index b2555838..c7172c40 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iradix.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iradix.go @@ -2,6 +2,7 @@ package iradix import ( "bytes" + "strings" "github.com/hashicorp/golang-lru/simplelru" ) @@ -11,7 +12,9 @@ const ( // cache used per transaction. This is used to cache the updates // to the nodes near the root, while the leaves do not need to be // cached. This is important for very large transactions to prevent - // the modified cache from growing to be enormous. + // the modified cache from growing to be enormous. This is also used + // to set the max size of the mutation notify maps since those should + // also be bounded in a similar way. defaultModifiedCache = 8192 ) @@ -27,7 +30,11 @@ type Tree struct { // New returns an empty Tree func New() *Tree { - t := &Tree{root: &Node{}} + t := &Tree{ + root: &Node{ + mutateCh: make(chan struct{}), + }, + } return t } @@ -40,75 +47,208 @@ func (t *Tree) Len() int { // atomically and returns a new tree when committed. A transaction // is not thread safe, and should only be used by a single goroutine. type Txn struct { - root *Node - size int - modified *simplelru.LRU + // root is the modified root for the transaction. + root *Node + + // snap is a snapshot of the root node for use if we have to run the + // slow notify algorithm. + snap *Node + + // size tracks the size of the tree as it is modified during the + // transaction. + size int + + // writable is a cache of writable nodes that have been created during + // the course of the transaction. This allows us to re-use the same + // nodes for further writes and avoid unnecessary copies of nodes that + // have never been exposed outside the transaction. This will only hold + // up to defaultModifiedCache number of entries. + writable *simplelru.LRU + + // trackChannels is used to hold channels that need to be notified to + // signal mutation of the tree. This will only hold up to + // defaultModifiedCache number of entries, after which we will set the + // trackOverflow flag, which will cause us to use a more expensive + // algorithm to perform the notifications. Mutation tracking is only + // performed if trackMutate is true. + trackChannels map[chan struct{}]struct{} + trackOverflow bool + trackMutate bool } // Txn starts a new transaction that can be used to mutate the tree func (t *Tree) Txn() *Txn { txn := &Txn{ root: t.root, + snap: t.root, size: t.size, } return txn } -// writeNode returns a node to be modified, if the current -// node as already been modified during the course of -// the transaction, it is used in-place. -func (t *Txn) writeNode(n *Node) *Node { - // Ensure the modified set exists - if t.modified == nil { +// TrackMutate can be used to toggle if mutations are tracked. If this is enabled +// then notifications will be issued for affected internal nodes and leaves when +// the transaction is committed. +func (t *Txn) TrackMutate(track bool) { + t.trackMutate = track +} + +// trackChannel safely attempts to track the given mutation channel, setting the +// overflow flag if we can no longer track any more. This limits the amount of +// state that will accumulate during a transaction and we have a slower algorithm +// to switch to if we overflow. +func (t *Txn) trackChannel(ch chan struct{}) { + // In overflow, make sure we don't store any more objects. + if t.trackOverflow { + return + } + + // If this would overflow the state we reject it and set the flag (since + // we aren't tracking everything that's required any longer). + if len(t.trackChannels) >= defaultModifiedCache { + // Mark that we are in the overflow state + t.trackOverflow = true + + // Clear the map so that the channels can be garbage collected. It is + // safe to do this since we have already overflowed and will be using + // the slow notify algorithm. + t.trackChannels = nil + return + } + + // Create the map on the fly when we need it. + if t.trackChannels == nil { + t.trackChannels = make(map[chan struct{}]struct{}) + } + + // Otherwise we are good to track it. + t.trackChannels[ch] = struct{}{} +} + +// writeNode returns a node to be modified, if the current node has already been +// modified during the course of the transaction, it is used in-place. Set +// forLeafUpdate to true if you are getting a write node to update the leaf, +// which will set leaf mutation tracking appropriately as well. +func (t *Txn) writeNode(n *Node, forLeafUpdate bool) *Node { + // Ensure the writable set exists. + if t.writable == nil { lru, err := simplelru.NewLRU(defaultModifiedCache, nil) if err != nil { panic(err) } - t.modified = lru + t.writable = lru } - // If this node has already been modified, we can - // continue to use it during this transaction. - if _, ok := t.modified.Get(n); ok { + // If this node has already been modified, we can continue to use it + // during this transaction. We know that we don't need to track it for + // a node update since the node is writable, but if this is for a leaf + // update we track it, in case the initial write to this node didn't + // update the leaf. + if _, ok := t.writable.Get(n); ok { + if t.trackMutate && forLeafUpdate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } return n } - // Copy the existing node - nc := new(Node) + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(n.mutateCh) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && forLeafUpdate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } + + // Copy the existing node. If you have set forLeafUpdate it will be + // safe to replace this leaf with another after you get your node for + // writing. You MUST replace it, because the channel associated with + // this leaf will be closed when this transaction is committed. + nc := &Node{ + mutateCh: make(chan struct{}), + leaf: n.leaf, + } if n.prefix != nil { nc.prefix = make([]byte, len(n.prefix)) copy(nc.prefix, n.prefix) } - if n.leaf != nil { - nc.leaf = new(leafNode) - *nc.leaf = *n.leaf - } if len(n.edges) != 0 { nc.edges = make([]edge, len(n.edges)) copy(nc.edges, n.edges) } - // Mark this node as modified - t.modified.Add(n, nil) + // Mark this node as writable. + t.writable.Add(nc, nil) return nc } +// Visit all the nodes in the tree under n, and add their mutateChannels to the transaction +// Returns the size of the subtree visited +func (t *Txn) trackChannelsAndCount(n *Node) int { + // Count only leaf nodes + leaves := 0 + if n.leaf != nil { + leaves = 1 + } + // Mark this node as being mutated. + if t.trackMutate { + t.trackChannel(n.mutateCh) + } + + // Mark its leaf as being mutated, if appropriate. + if t.trackMutate && n.leaf != nil { + t.trackChannel(n.leaf.mutateCh) + } + + // Recurse on the children + for _, e := range n.edges { + leaves += t.trackChannelsAndCount(e.node) + } + return leaves +} + +// mergeChild is called to collapse the given node with its child. This is only +// called when the given node is not a leaf and has a single edge. +func (t *Txn) mergeChild(n *Node) { + // Mark the child node as being mutated since we are about to abandon + // it. We don't need to mark the leaf since we are retaining it if it + // is there. + e := n.edges[0] + child := e.node + if t.trackMutate { + t.trackChannel(child.mutateCh) + } + + // Merge the nodes. + n.prefix = concat(n.prefix, child.prefix) + n.leaf = child.leaf + if len(child.edges) != 0 { + n.edges = make([]edge, len(child.edges)) + copy(n.edges, child.edges) + } else { + n.edges = nil + } +} + // insert does a recursive insertion func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface{}, bool) { - // Handle key exhaution + // Handle key exhaustion if len(search) == 0 { - nc := t.writeNode(n) + var oldVal interface{} + didUpdate := false if n.isLeaf() { - old := nc.leaf.val - nc.leaf.val = v - return nc, old, true - } else { - nc.leaf = &leafNode{ - key: k, - val: v, - } - return nc, nil, false + oldVal = n.leaf.val + didUpdate = true } + + nc := t.writeNode(n, true) + nc.leaf = &leafNode{ + mutateCh: make(chan struct{}), + key: k, + val: v, + } + return nc, oldVal, didUpdate } // Look for the edge @@ -119,14 +259,16 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface e := edge{ label: search[0], node: &Node{ + mutateCh: make(chan struct{}), leaf: &leafNode{ - key: k, - val: v, + mutateCh: make(chan struct{}), + key: k, + val: v, }, prefix: search, }, } - nc := t.writeNode(n) + nc := t.writeNode(n, false) nc.addEdge(e) return nc, nil, false } @@ -137,7 +279,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface search = search[commonPrefix:] newChild, oldVal, didUpdate := t.insert(child, k, search, v) if newChild != nil { - nc := t.writeNode(n) + nc := t.writeNode(n, false) nc.edges[idx].node = newChild return nc, oldVal, didUpdate } @@ -145,9 +287,10 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface } // Split the node - nc := t.writeNode(n) + nc := t.writeNode(n, false) splitNode := &Node{ - prefix: search[:commonPrefix], + mutateCh: make(chan struct{}), + prefix: search[:commonPrefix], } nc.replaceEdge(edge{ label: search[0], @@ -155,7 +298,7 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface }) // Restore the existing child node - modChild := t.writeNode(child) + modChild := t.writeNode(child, false) splitNode.addEdge(edge{ label: modChild.prefix[commonPrefix], node: modChild, @@ -164,8 +307,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface // Create a new leaf node leaf := &leafNode{ - key: k, - val: v, + mutateCh: make(chan struct{}), + key: k, + val: v, } // If the new key is a subset, add to to this node @@ -179,8 +323,9 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface splitNode.addEdge(edge{ label: search[0], node: &Node{ - leaf: leaf, - prefix: search, + mutateCh: make(chan struct{}), + leaf: leaf, + prefix: search, }, }) return nc, nil, false @@ -188,19 +333,19 @@ func (t *Txn) insert(n *Node, k, search []byte, v interface{}) (*Node, interface // delete does a recursive deletion func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { - // Check for key exhaution + // Check for key exhaustion if len(search) == 0 { if !n.isLeaf() { return nil, nil } // Remove the leaf node - nc := t.writeNode(n) + nc := t.writeNode(n, true) nc.leaf = nil // Check if this node should be merged if n != t.root && len(nc.edges) == 1 { - nc.mergeChild() + t.mergeChild(nc) } return nc, n.leaf } @@ -219,14 +364,17 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { return nil, nil } - // Copy this node - nc := t.writeNode(n) + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChild() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. + nc := t.writeNode(n, false) // Delete the edge if the node has no edges if newChild.leaf == nil && len(newChild.edges) == 0 { nc.delEdge(label) if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() { - nc.mergeChild() + t.mergeChild(nc) } } else { nc.edges[idx].node = newChild @@ -234,6 +382,56 @@ func (t *Txn) delete(parent, n *Node, search []byte) (*Node, *leafNode) { return nc, leaf } +// delete does a recursive deletion +func (t *Txn) deletePrefix(parent, n *Node, search []byte) (*Node, int) { + // Check for key exhaustion + if len(search) == 0 { + nc := t.writeNode(n, true) + if n.isLeaf() { + nc.leaf = nil + } + nc.edges = nil + return nc, t.trackChannelsAndCount(n) + } + + // Look for an edge + label := search[0] + idx, child := n.getEdge(label) + // We make sure that either the child node's prefix starts with the search term, or the search term starts with the child node's prefix + // Need to do both so that we can delete prefixes that don't correspond to any node in the tree + if child == nil || (!bytes.HasPrefix(child.prefix, search) && !bytes.HasPrefix(search, child.prefix)) { + return nil, 0 + } + + // Consume the search prefix + if len(child.prefix) > len(search) { + search = []byte("") + } else { + search = search[len(child.prefix):] + } + newChild, numDeletions := t.deletePrefix(n, child, search) + if newChild == nil { + return nil, 0 + } + // Copy this node. WATCH OUT - it's safe to pass "false" here because we + // will only ADD a leaf via nc.mergeChild() if there isn't one due to + // the !nc.isLeaf() check in the logic just below. This is pretty subtle, + // so be careful if you change any of the logic here. + + nc := t.writeNode(n, false) + + // Delete the edge if the node has no edges + if newChild.leaf == nil && len(newChild.edges) == 0 { + nc.delEdge(label) + if n != t.root && len(nc.edges) == 1 && !nc.isLeaf() { + t.mergeChild(nc) + } + } else { + nc.edges[idx].node = newChild + } + return nc, numDeletions +} + // Insert is used to add or update a given key. The return provides // the previous value and a bool indicating if any was set. func (t *Txn) Insert(k []byte, v interface{}) (interface{}, bool) { @@ -261,6 +459,19 @@ func (t *Txn) Delete(k []byte) (interface{}, bool) { return nil, false } +// DeletePrefix is used to delete an entire subtree that matches the prefix +// This will delete all nodes under that prefix +func (t *Txn) DeletePrefix(prefix []byte) bool { + newRoot, numDeletions := t.deletePrefix(nil, t.root, prefix) + if newRoot != nil { + t.root = newRoot + t.size = t.size - numDeletions + return true + } + return false + +} + // Root returns the current root of the radix tree within this // transaction. The root is not safe across insert and delete operations, // but can be used to read the current state during a transaction. @@ -274,10 +485,115 @@ func (t *Txn) Get(k []byte) (interface{}, bool) { return t.root.Get(k) } -// Commit is used to finalize the transaction and return a new tree +// GetWatch is used to lookup a specific key, returning +// the watch channel, value and if it was found +func (t *Txn) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { + return t.root.GetWatch(k) +} + +// Commit is used to finalize the transaction and return a new tree. If mutation +// tracking is turned on then notifications will also be issued. func (t *Txn) Commit() *Tree { - t.modified = nil - return &Tree{t.root, t.size} + nt := t.CommitOnly() + if t.trackMutate { + t.Notify() + } + return nt +} + +// CommitOnly is used to finalize the transaction and return a new tree, but +// does not issue any notifications until Notify is called. +func (t *Txn) CommitOnly() *Tree { + nt := &Tree{t.root, t.size} + t.writable = nil + return nt +} + +// slowNotify does a complete comparison of the before and after trees in order +// to trigger notifications. This doesn't require any additional state but it +// is very expensive to compute. +func (t *Txn) slowNotify() { + snapIter := t.snap.rawIterator() + rootIter := t.root.rawIterator() + for snapIter.Front() != nil || rootIter.Front() != nil { + // If we've exhausted the nodes in the old snapshot, we know + // there's nothing remaining to notify. + if snapIter.Front() == nil { + return + } + snapElem := snapIter.Front() + + // If we've exhausted the nodes in the new root, we know we need + // to invalidate everything that remains in the old snapshot. We + // know from the loop condition there's something in the old + // snapshot. + if rootIter.Front() == nil { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // Do one string compare so we can check the various conditions + // below without repeating the compare. + cmp := strings.Compare(snapIter.Path(), rootIter.Path()) + + // If the snapshot is behind the root, then we must have deleted + // this node during the transaction. + if cmp < 0 { + close(snapElem.mutateCh) + if snapElem.isLeaf() { + close(snapElem.leaf.mutateCh) + } + snapIter.Next() + continue + } + + // If the snapshot is ahead of the root, then we must have added + // this node during the transaction. + if cmp > 0 { + rootIter.Next() + continue + } + + // If we have the same path, then we need to see if we mutated a + // node and possibly the leaf. + rootElem := rootIter.Front() + if snapElem != rootElem { + close(snapElem.mutateCh) + if snapElem.leaf != nil && (snapElem.leaf != rootElem.leaf) { + close(snapElem.leaf.mutateCh) + } + } + snapIter.Next() + rootIter.Next() + } +} + +// Notify is used along with TrackMutate to trigger notifications. This must +// only be done once a transaction is committed via CommitOnly, and it is called +// automatically by Commit. +func (t *Txn) Notify() { + if !t.trackMutate { + return + } + + // If we've overflowed the tracking state we can't use it in any way and + // need to do a full tree compare. + if t.trackOverflow { + t.slowNotify() + } else { + for ch := range t.trackChannels { + close(ch) + } + } + + // Clean up the tracking state so that a re-notify is safe (will trigger + // the else clause above which will be a no-op). + t.trackChannels = nil + t.trackOverflow = false } // Insert is used to add or update a given key. The return provides @@ -296,6 +612,14 @@ func (t *Tree) Delete(k []byte) (*Tree, interface{}, bool) { return txn.Commit(), old, ok } +// DeletePrefix is used to delete all nodes starting with a given prefix. Returns the new tree, +// and a bool indicating if the prefix matched any nodes +func (t *Tree) DeletePrefix(k []byte) (*Tree, bool) { + txn := t.Txn() + ok := txn.DeletePrefix(k) + return txn.Commit(), ok +} + // Root returns the root node of the tree which can be used for richer // query operations. func (t *Tree) Root() *Node { diff --git a/vendor/github.com/hashicorp/go-immutable-radix/iter.go b/vendor/github.com/hashicorp/go-immutable-radix/iter.go index 75cbaa11..9815e025 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/iter.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/iter.go @@ -9,11 +9,13 @@ type Iterator struct { stack []edges } -// SeekPrefix is used to seek the iterator to a given prefix -func (i *Iterator) SeekPrefix(prefix []byte) { +// SeekPrefixWatch is used to seek the iterator to a given prefix +// and returns the watch channel of the finest granularity +func (i *Iterator) SeekPrefixWatch(prefix []byte) (watch <-chan struct{}) { // Wipe the stack i.stack = nil n := i.node + watch = n.mutateCh search := prefix for { // Check for key exhaution @@ -29,6 +31,9 @@ func (i *Iterator) SeekPrefix(prefix []byte) { return } + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + // Consume the search prefix if bytes.HasPrefix(search, n.prefix) { search = search[len(n.prefix):] @@ -43,6 +48,11 @@ func (i *Iterator) SeekPrefix(prefix []byte) { } } +// SeekPrefix is used to seek the iterator to a given prefix +func (i *Iterator) SeekPrefix(prefix []byte) { + i.SeekPrefixWatch(prefix) +} + // Next returns the next node in order func (i *Iterator) Next() ([]byte, interface{}, bool) { // Initialize our stack if needed diff --git a/vendor/github.com/hashicorp/go-immutable-radix/node.go b/vendor/github.com/hashicorp/go-immutable-radix/node.go index fea6f634..ef494fa7 100644 --- a/vendor/github.com/hashicorp/go-immutable-radix/node.go +++ b/vendor/github.com/hashicorp/go-immutable-radix/node.go @@ -12,8 +12,9 @@ type WalkFn func(k []byte, v interface{}) bool // leafNode is used to represent a value type leafNode struct { - key []byte - val interface{} + mutateCh chan struct{} + key []byte + val interface{} } // edge is used to represent an edge node @@ -24,6 +25,9 @@ type edge struct { // Node is an immutable node in the radix tree type Node struct { + // mutateCh is closed if this node is modified + mutateCh chan struct{} + // leaf is used to store possible leaf leaf *leafNode @@ -87,31 +91,14 @@ func (n *Node) delEdge(label byte) { } } -func (n *Node) mergeChild() { - e := n.edges[0] - child := e.node - n.prefix = concat(n.prefix, child.prefix) - if child.leaf != nil { - n.leaf = new(leafNode) - *n.leaf = *child.leaf - } else { - n.leaf = nil - } - if len(child.edges) != 0 { - n.edges = make([]edge, len(child.edges)) - copy(n.edges, child.edges) - } else { - n.edges = nil - } -} - -func (n *Node) Get(k []byte) (interface{}, bool) { +func (n *Node) GetWatch(k []byte) (<-chan struct{}, interface{}, bool) { search := k + watch := n.mutateCh for { - // Check for key exhaution + // Check for key exhaustion if len(search) == 0 { if n.isLeaf() { - return n.leaf.val, true + return n.leaf.mutateCh, n.leaf.val, true } break } @@ -122,6 +109,9 @@ func (n *Node) Get(k []byte) (interface{}, bool) { break } + // Update to the finest granularity as the search makes progress + watch = n.mutateCh + // Consume the search prefix if bytes.HasPrefix(search, n.prefix) { search = search[len(n.prefix):] @@ -129,7 +119,12 @@ func (n *Node) Get(k []byte) (interface{}, bool) { break } } - return nil, false + return watch, nil, false +} + +func (n *Node) Get(k []byte) (interface{}, bool) { + _, val, ok := n.GetWatch(k) + return val, ok } // LongestPrefix is like Get, but instead of an @@ -204,6 +199,14 @@ func (n *Node) Iterator() *Iterator { return &Iterator{node: n} } +// rawIterator is used to return a raw iterator at the given node to walk the +// tree. +func (n *Node) rawIterator() *rawIterator { + iter := &rawIterator{node: n} + iter.Next() + return iter +} + // Walk is used to walk the tree func (n *Node) Walk(fn WalkFn) { recursiveWalk(n, fn) @@ -271,6 +274,66 @@ func (n *Node) WalkPath(path []byte, fn WalkFn) { } } +func (n *Node) Seek(prefix []byte) *Seeker { + search := prefix + p := &pos{n: n} + for { + // Check for key exhaution + if len(search) == 0 { + return &Seeker{p} + } + + num := len(n.edges) + idx := sort.Search(num, func(i int) bool { + return n.edges[i].label >= search[0] + }) + p.current = idx + if idx < len(n.edges) { + n = n.edges[idx].node + if bytes.HasPrefix(search, n.prefix) && len(n.edges) > 0 { + search = search[len(n.prefix):] + p.current++ + p = &pos{n: n, prev: p} + continue + } + } + p.current++ + return &Seeker{p} + } +} + +type Seeker struct { + *pos +} + +type pos struct { + n *Node + current int + prev *pos + isLeaf bool +} + +func (s *Seeker) Next() (k []byte, v interface{}, ok bool) { + if s.current >= len(s.n.edges) { + if s.prev == nil { + return nil, nil, false + } + s.pos = s.prev + return s.Next() + } + + edge := s.n.edges[s.current] + s.current++ + if edge.node.leaf != nil && !s.isLeaf { + s.isLeaf = true + s.current-- + return edge.node.leaf.key, edge.node.leaf.val, true + } + s.isLeaf = false + s.pos = &pos{n: edge.node, prev: s.pos} + return s.Next() +} + // recursiveWalk is used to do a pre-order walk of a node // recursively. Returns true if the walk should be aborted func recursiveWalk(n *Node, fn WalkFn) bool { diff --git a/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go new file mode 100644 index 00000000..04814c13 --- /dev/null +++ b/vendor/github.com/hashicorp/go-immutable-radix/raw_iter.go @@ -0,0 +1,78 @@ +package iradix + +// rawIterator visits each of the nodes in the tree, even the ones that are not +// leaves. It keeps track of the effective path (what a leaf at a given node +// would be called), which is useful for comparing trees. +type rawIterator struct { + // node is the starting node in the tree for the iterator. + node *Node + + // stack keeps track of edges in the frontier. + stack []rawStackEntry + + // pos is the current position of the iterator. + pos *Node + + // path is the effective path of the current iterator position, + // regardless of whether the current node is a leaf. + path string +} + +// rawStackEntry is used to keep track of the cumulative common path as well as +// its associated edges in the frontier. +type rawStackEntry struct { + path string + edges edges +} + +// Front returns the current node that has been iterated to. +func (i *rawIterator) Front() *Node { + return i.pos +} + +// Path returns the effective path of the current node, even if it's not actually +// a leaf. +func (i *rawIterator) Path() string { + return i.path +} + +// Next advances the iterator to the next node. +func (i *rawIterator) Next() { + // Initialize our stack if needed. + if i.stack == nil && i.node != nil { + i.stack = []rawStackEntry{ + rawStackEntry{ + edges: edges{ + edge{node: i.node}, + }, + }, + } + } + + for len(i.stack) > 0 { + // Inspect the last element of the stack. + n := len(i.stack) + last := i.stack[n-1] + elem := last.edges[0].node + + // Update the stack. + if len(last.edges) > 1 { + i.stack[n-1].edges = last.edges[1:] + } else { + i.stack = i.stack[:n-1] + } + + // Push the edges onto the frontier. + if len(elem.edges) > 0 { + path := last.path + string(elem.prefix) + i.stack = append(i.stack, rawStackEntry{path, elem.edges}) + } + + i.pos = elem + i.path = last.path + string(elem.prefix) + return + } + + i.pos = nil + i.path = "" +}