diff --git a/util/flightcontrol/flightcontrol.go b/util/flightcontrol/flightcontrol.go index a543364c..fc9f7272 100644 --- a/util/flightcontrol/flightcontrol.go +++ b/util/flightcontrol/flightcontrol.go @@ -130,7 +130,11 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) { c.mu.Lock() // detect case where caller has just returned, let it clean up before select { - case <-c.ready: // could return if no error + case <-c.ready: + c.mu.Unlock() + <-c.cleaned + return nil, errRetry + case <-c.ctx.done: // could return if no error c.mu.Unlock() <-c.cleaned return nil, errRetry @@ -141,6 +145,10 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) { if ok { c.progressState.add(pw) } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + c.ctxs = append(c.ctxs, ctx) c.mu.Unlock() @@ -149,18 +157,16 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) { select { case <-ctx.Done(): - select { - case <-c.ctx.Done(): + if c.ctx.checkDone() { // if this cancelled the last context, then wait for function to shut down // and don't accept any more callers <-c.ready return c.result, c.err - default: - if ok { - c.progressState.close(pw) - } - return nil, ctx.Err() } + if ok { + c.progressState.close(pw) + } + return nil, ctx.Err() case <-c.ready: return c.result, c.err // shared not implemented yet } @@ -183,9 +189,6 @@ func (c *call) Deadline() (deadline time.Time, ok bool) { } func (c *call) Done() <-chan struct{} { - c.mu.Lock() - c.ctx.signal() - c.mu.Unlock() return c.ctx.done } @@ -238,23 +241,28 @@ func newContext(c *call) *sharedContext { return &sharedContext{call: c, done: make(chan struct{})} } -// call with lock -func (c *sharedContext) signal() { +func (sc *sharedContext) checkDone() bool { + sc.mu.Lock() select { - case <-c.done: + case <-sc.done: + sc.mu.Unlock() + return true default: - var err error - for _, ctx := range c.ctxs { - select { - case <-ctx.Done(): - err = ctx.Err() - default: - return - } - } - c.err = err - close(c.done) } + var err error + for _, ctx := range sc.ctxs { + select { + case <-ctx.Done(): + err = ctx.Err() + default: + sc.mu.Unlock() + return false + } + } + sc.err = err + close(sc.done) + sc.mu.Unlock() + return true } type rawProgressWriter interface { diff --git a/util/flightcontrol/flightcontrol_test.go b/util/flightcontrol/flightcontrol_test.go index 45a2e96f..bb02386f 100644 --- a/util/flightcontrol/flightcontrol_test.go +++ b/util/flightcontrol/flightcontrol_test.go @@ -83,6 +83,61 @@ func TestCancelOne(t *testing.T) { assert.Equal(t, counter, int64(1)) } +func TestCancelRace(t *testing.T) { + // t.Parallel() // disabled for better timing consistency. works with parallel as well + + g := &Group{} + ctx, cancel := context.WithCancel(context.Background()) + + kick := make(chan struct{}) + wait := make(chan struct{}) + + count := 0 + + // first run cancels context, second returns cleanly + f := func(ctx context.Context) (interface{}, error) { + done := ctx.Done() + if count > 0 { + time.Sleep(100 * time.Millisecond) + return nil, nil + } + go func() { + for { + select { + case <-wait: + return + default: + ctx.Done() + } + } + }() + count++ + time.Sleep(50 * time.Millisecond) + close(kick) + time.Sleep(50 * time.Millisecond) + select { + case <-done: + return nil, ctx.Err() + case <-time.After(200 * time.Millisecond): + } + return nil, nil + } + + go func() { + defer close(wait) + <-kick + cancel() + time.Sleep(5 * time.Millisecond) + _, err := g.Do(context.Background(), "foo", f) + require.NoError(t, err) + }() + + _, err := g.Do(ctx, "foo", f) + require.Error(t, err) + require.Equal(t, true, errors.Is(err, context.Canceled)) + <-wait +} + func TestCancelBoth(t *testing.T) { t.Parallel() g := &Group{} @@ -133,7 +188,6 @@ func TestCancelBoth(t *testing.T) { assert.Equal(t, "", r1) assert.Equal(t, "", r2) assert.Equal(t, counter, int64(1)) - ret1, err := g.Do(context.TODO(), "foo", f) assert.NoError(t, err) assert.Equal(t, ret1, "bar")