flightcontrol: fix possible invalid cancellation
There was a race with context getting cancelled and new request arriving that could resulted new request to receive cancelled result as well. This happened because lock was held when getting the Done() channel but it could have been already released by the time returned channel was closed. Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>v0.8
parent
bf5e780c5e
commit
e56e7ba46b
|
@ -130,7 +130,11 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
|
|||
c.mu.Lock()
|
||||
// detect case where caller has just returned, let it clean up before
|
||||
select {
|
||||
case <-c.ready: // could return if no error
|
||||
case <-c.ready:
|
||||
c.mu.Unlock()
|
||||
<-c.cleaned
|
||||
return nil, errRetry
|
||||
case <-c.ctx.done: // could return if no error
|
||||
c.mu.Unlock()
|
||||
<-c.cleaned
|
||||
return nil, errRetry
|
||||
|
@ -141,6 +145,10 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
|
|||
if ok {
|
||||
c.progressState.add(pw)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
c.ctxs = append(c.ctxs, ctx)
|
||||
|
||||
c.mu.Unlock()
|
||||
|
@ -149,18 +157,16 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
|
|||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
if c.ctx.checkDone() {
|
||||
// if this cancelled the last context, then wait for function to shut down
|
||||
// and don't accept any more callers
|
||||
<-c.ready
|
||||
return c.result, c.err
|
||||
default:
|
||||
if ok {
|
||||
c.progressState.close(pw)
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if ok {
|
||||
c.progressState.close(pw)
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
case <-c.ready:
|
||||
return c.result, c.err // shared not implemented yet
|
||||
}
|
||||
|
@ -183,9 +189,6 @@ func (c *call) Deadline() (deadline time.Time, ok bool) {
|
|||
}
|
||||
|
||||
func (c *call) Done() <-chan struct{} {
|
||||
c.mu.Lock()
|
||||
c.ctx.signal()
|
||||
c.mu.Unlock()
|
||||
return c.ctx.done
|
||||
}
|
||||
|
||||
|
@ -238,23 +241,28 @@ func newContext(c *call) *sharedContext {
|
|||
return &sharedContext{call: c, done: make(chan struct{})}
|
||||
}
|
||||
|
||||
// call with lock
|
||||
func (c *sharedContext) signal() {
|
||||
func (sc *sharedContext) checkDone() bool {
|
||||
sc.mu.Lock()
|
||||
select {
|
||||
case <-c.done:
|
||||
case <-sc.done:
|
||||
sc.mu.Unlock()
|
||||
return true
|
||||
default:
|
||||
var err error
|
||||
for _, ctx := range c.ctxs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
c.err = err
|
||||
close(c.done)
|
||||
}
|
||||
var err error
|
||||
for _, ctx := range sc.ctxs {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
default:
|
||||
sc.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
}
|
||||
sc.err = err
|
||||
close(sc.done)
|
||||
sc.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
type rawProgressWriter interface {
|
||||
|
|
|
@ -83,6 +83,61 @@ func TestCancelOne(t *testing.T) {
|
|||
assert.Equal(t, counter, int64(1))
|
||||
}
|
||||
|
||||
func TestCancelRace(t *testing.T) {
|
||||
// t.Parallel() // disabled for better timing consistency. works with parallel as well
|
||||
|
||||
g := &Group{}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
kick := make(chan struct{})
|
||||
wait := make(chan struct{})
|
||||
|
||||
count := 0
|
||||
|
||||
// first run cancels context, second returns cleanly
|
||||
f := func(ctx context.Context) (interface{}, error) {
|
||||
done := ctx.Done()
|
||||
if count > 0 {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil, nil
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-wait:
|
||||
return
|
||||
default:
|
||||
ctx.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
count++
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(kick)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
select {
|
||||
case <-done:
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(wait)
|
||||
<-kick
|
||||
cancel()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
_, err := g.Do(context.Background(), "foo", f)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
_, err := g.Do(ctx, "foo", f)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, true, errors.Is(err, context.Canceled))
|
||||
<-wait
|
||||
}
|
||||
|
||||
func TestCancelBoth(t *testing.T) {
|
||||
t.Parallel()
|
||||
g := &Group{}
|
||||
|
@ -133,7 +188,6 @@ func TestCancelBoth(t *testing.T) {
|
|||
assert.Equal(t, "", r1)
|
||||
assert.Equal(t, "", r2)
|
||||
assert.Equal(t, counter, int64(1))
|
||||
|
||||
ret1, err := g.Do(context.TODO(), "foo", f)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, ret1, "bar")
|
||||
|
|
Loading…
Reference in New Issue