From e56e7ba46ba649c5a1d9d9d7790ca0a9b38b3f12 Mon Sep 17 00:00:00 2001 From: Tonis Tiigi Date: Fri, 11 Dec 2020 00:44:07 -0800 Subject: [PATCH] flightcontrol: fix possible invalid cancellation There was a race with context getting cancelled and new request arriving that could resulted new request to receive cancelled result as well. This happened because lock was held when getting the Done() channel but it could have been already released by the time returned channel was closed. Signed-off-by: Tonis Tiigi --- util/flightcontrol/flightcontrol.go | 58 ++++++++++++++---------- util/flightcontrol/flightcontrol_test.go | 56 ++++++++++++++++++++++- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/util/flightcontrol/flightcontrol.go b/util/flightcontrol/flightcontrol.go index a543364c..fc9f7272 100644 --- a/util/flightcontrol/flightcontrol.go +++ b/util/flightcontrol/flightcontrol.go @@ -130,7 +130,11 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) { c.mu.Lock() // detect case where caller has just returned, let it clean up before select { - case <-c.ready: // could return if no error + case <-c.ready: + c.mu.Unlock() + <-c.cleaned + return nil, errRetry + case <-c.ctx.done: // could return if no error c.mu.Unlock() <-c.cleaned return nil, errRetry @@ -141,6 +145,10 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) { if ok { c.progressState.add(pw) } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + c.ctxs = append(c.ctxs, ctx) c.mu.Unlock() @@ -149,18 +157,16 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) { select { case <-ctx.Done(): - select { - case <-c.ctx.Done(): + if c.ctx.checkDone() { // if this cancelled the last context, then wait for function to shut down // and don't accept any more callers <-c.ready return c.result, c.err - default: - if ok { - c.progressState.close(pw) - } - return nil, ctx.Err() } + if ok { + c.progressState.close(pw) + } + return nil, ctx.Err() case <-c.ready: return c.result, c.err // shared not implemented yet } @@ -183,9 +189,6 @@ func (c *call) Deadline() (deadline time.Time, ok bool) { } func (c *call) Done() <-chan struct{} { - c.mu.Lock() - c.ctx.signal() - c.mu.Unlock() return c.ctx.done } @@ -238,23 +241,28 @@ func newContext(c *call) *sharedContext { return &sharedContext{call: c, done: make(chan struct{})} } -// call with lock -func (c *sharedContext) signal() { +func (sc *sharedContext) checkDone() bool { + sc.mu.Lock() select { - case <-c.done: + case <-sc.done: + sc.mu.Unlock() + return true default: - var err error - for _, ctx := range c.ctxs { - select { - case <-ctx.Done(): - err = ctx.Err() - default: - return - } - } - c.err = err - close(c.done) } + var err error + for _, ctx := range sc.ctxs { + select { + case <-ctx.Done(): + err = ctx.Err() + default: + sc.mu.Unlock() + return false + } + } + sc.err = err + close(sc.done) + sc.mu.Unlock() + return true } type rawProgressWriter interface { diff --git a/util/flightcontrol/flightcontrol_test.go b/util/flightcontrol/flightcontrol_test.go index 45a2e96f..bb02386f 100644 --- a/util/flightcontrol/flightcontrol_test.go +++ b/util/flightcontrol/flightcontrol_test.go @@ -83,6 +83,61 @@ func TestCancelOne(t *testing.T) { assert.Equal(t, counter, int64(1)) } +func TestCancelRace(t *testing.T) { + // t.Parallel() // disabled for better timing consistency. works with parallel as well + + g := &Group{} + ctx, cancel := context.WithCancel(context.Background()) + + kick := make(chan struct{}) + wait := make(chan struct{}) + + count := 0 + + // first run cancels context, second returns cleanly + f := func(ctx context.Context) (interface{}, error) { + done := ctx.Done() + if count > 0 { + time.Sleep(100 * time.Millisecond) + return nil, nil + } + go func() { + for { + select { + case <-wait: + return + default: + ctx.Done() + } + } + }() + count++ + time.Sleep(50 * time.Millisecond) + close(kick) + time.Sleep(50 * time.Millisecond) + select { + case <-done: + return nil, ctx.Err() + case <-time.After(200 * time.Millisecond): + } + return nil, nil + } + + go func() { + defer close(wait) + <-kick + cancel() + time.Sleep(5 * time.Millisecond) + _, err := g.Do(context.Background(), "foo", f) + require.NoError(t, err) + }() + + _, err := g.Do(ctx, "foo", f) + require.Error(t, err) + require.Equal(t, true, errors.Is(err, context.Canceled)) + <-wait +} + func TestCancelBoth(t *testing.T) { t.Parallel() g := &Group{} @@ -133,7 +188,6 @@ func TestCancelBoth(t *testing.T) { assert.Equal(t, "", r1) assert.Equal(t, "", r2) assert.Equal(t, counter, int64(1)) - ret1, err := g.Do(context.TODO(), "foo", f) assert.NoError(t, err) assert.Equal(t, ret1, "bar")