From b3e8c63a48ad8c015f5631fc1947945b229b3919 Mon Sep 17 00:00:00 2001 From: Edgar Lee Date: Wed, 26 Jan 2022 11:55:45 -0800 Subject: [PATCH] Local should use session ID in op and only fallback to session group if failed before transfer started Signed-off-by: Edgar Lee --- session/filesync/filesync.go | 16 ++++++++++++-- source/local/local.go | 43 +++++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/session/filesync/filesync.go b/session/filesync/filesync.go index 75e23cb2..ae3f29f8 100644 --- a/session/filesync/filesync.go +++ b/session/filesync/filesync.go @@ -70,7 +70,7 @@ func (sp *fsSyncProvider) handle(method string, stream grpc.ServerStream) (retEr } } if pr == nil { - return errors.New("failed to negotiate protocol") + return InvalidSessionError{errors.New("failed to negotiate protocol")} } opts, _ := metadata.FromIncomingContext(stream.Context()) // if no metadata continue with empty object @@ -83,7 +83,7 @@ func (sp *fsSyncProvider) handle(method string, stream grpc.ServerStream) (retEr dir, ok := sp.dirs[dirName] if !ok { - return status.Errorf(codes.NotFound, "no access allowed to dir %q", dirName) + return InvalidSessionError{status.Errorf(codes.NotFound, "no access allowed to dir %q", dirName)} } excludes := opts[keyExcludePatterns] @@ -317,3 +317,15 @@ func CopyFileWriter(ctx context.Context, md map[string]string, c session.Caller) return newStreamWriter(cc), nil } + +type InvalidSessionError struct { + err error +} + +func (e InvalidSessionError) Error() string { + return e.err.Error() +} + +func (e InvalidSessionError) Unwrap() error { + return e.err +} diff --git a/source/local/local.go b/source/local/local.go index 22562206..d2cd9d98 100644 --- a/source/local/local.go +++ b/source/local/local.go @@ -24,8 +24,6 @@ import ( "github.com/tonistiigi/fsutil" fstypes "github.com/tonistiigi/fsutil/types" "golang.org/x/time/rate" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) type Opt struct { @@ -89,22 +87,44 @@ func (ls *localSourceHandler) CacheKey(ctx context.Context, g session.Group, ind } func (ls *localSourceHandler) Snapshot(ctx context.Context, g session.Group) (cache.ImmutableRef, error) { + sessionID := ls.src.SessionID + if sessionID == "" { + return ls.snapshotWithAnySession(ctx, g) + } + + timeoutCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + caller, err := ls.sm.Get(timeoutCtx, sessionID, false) + if err != nil { + return ls.snapshotWithAnySession(ctx, g) + } + + ref, err := ls.snapshot(ctx, caller) + if err != nil { + var serr filesync.InvalidSessionError + if errors.As(err, &serr) { + return ls.snapshotWithAnySession(ctx, g) + } + return nil, err + } + return ref, nil +} + +func (ls *localSourceHandler) snapshotWithAnySession(ctx context.Context, g session.Group) (cache.ImmutableRef, error) { var ref cache.ImmutableRef err := ls.sm.Any(ctx, g, func(ctx context.Context, _ string, c session.Caller) error { - r, err := ls.snapshot(ctx, g, c) + r, err := ls.snapshot(ctx, c) if err != nil { return err } ref = r return nil }) - if err != nil { - return nil, err - } - return ref, nil + return ref, err } -func (ls *localSourceHandler) snapshot(ctx context.Context, s session.Group, caller session.Caller) (out cache.ImmutableRef, retErr error) { +func (ls *localSourceHandler) snapshot(ctx context.Context, caller session.Caller) (out cache.ImmutableRef, retErr error) { sharedKey := ls.src.Name + ":" + ls.src.SharedKeyHint + ":" + caller.SharedKey() // TODO: replace caller.SharedKey() with source based hint from client(absolute-path+nodeid) var mutable cache.MutableRef @@ -123,7 +143,7 @@ func (ls *localSourceHandler) snapshot(ctx context.Context, s session.Group, cal } if mutable == nil { - m, err := ls.cm.New(ctx, nil, s, cache.CachePolicyRetain, cache.WithRecordType(client.UsageRecordTypeLocalSource), cache.WithDescription(fmt.Sprintf("local source for %s", ls.src.Name))) + m, err := ls.cm.New(ctx, nil, nil, cache.CachePolicyRetain, cache.WithRecordType(client.UsageRecordTypeLocalSource), cache.WithDescription(fmt.Sprintf("local source for %s", ls.src.Name))) if err != nil { return nil, err } @@ -142,7 +162,7 @@ func (ls *localSourceHandler) snapshot(ctx context.Context, s session.Group, cal } }() - mount, err := mutable.Mount(ctx, false, s) + mount, err := mutable.Mount(ctx, false, nil) if err != nil { return nil, err } @@ -193,9 +213,6 @@ func (ls *localSourceHandler) snapshot(ctx context.Context, s session.Group, cal } if err := filesync.FSSync(ctx, caller, opt); err != nil { - if status.Code(err) == codes.NotFound { - return nil, errors.Errorf("local source %s not enabled from the client", ls.src.Name) - } return nil, err }