flightcontrol: fix possible invalid cancellation

There was a race with context getting cancelled and new request
arriving that could resulted new request to receive cancelled result
as well. This happened because lock was held when getting the Done()
channel but it could have been already released by the time returned
channel was closed.

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
v0.8
Tonis Tiigi 2020-12-11 00:44:07 -08:00
parent bf5e780c5e
commit e56e7ba46b
2 changed files with 88 additions and 26 deletions

View File

@ -130,7 +130,11 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
c.mu.Lock()
// detect case where caller has just returned, let it clean up before
select {
case <-c.ready: // could return if no error
case <-c.ready:
c.mu.Unlock()
<-c.cleaned
return nil, errRetry
case <-c.ctx.done: // could return if no error
c.mu.Unlock()
<-c.cleaned
return nil, errRetry
@ -141,6 +145,10 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
if ok {
c.progressState.add(pw)
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
c.ctxs = append(c.ctxs, ctx)
c.mu.Unlock()
@ -149,18 +157,16 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
select {
case <-ctx.Done():
select {
case <-c.ctx.Done():
if c.ctx.checkDone() {
// if this cancelled the last context, then wait for function to shut down
// and don't accept any more callers
<-c.ready
return c.result, c.err
default:
if ok {
c.progressState.close(pw)
}
return nil, ctx.Err()
}
if ok {
c.progressState.close(pw)
}
return nil, ctx.Err()
case <-c.ready:
return c.result, c.err // shared not implemented yet
}
@ -183,9 +189,6 @@ func (c *call) Deadline() (deadline time.Time, ok bool) {
}
func (c *call) Done() <-chan struct{} {
c.mu.Lock()
c.ctx.signal()
c.mu.Unlock()
return c.ctx.done
}
@ -238,23 +241,28 @@ func newContext(c *call) *sharedContext {
return &sharedContext{call: c, done: make(chan struct{})}
}
// call with lock
func (c *sharedContext) signal() {
func (sc *sharedContext) checkDone() bool {
sc.mu.Lock()
select {
case <-c.done:
case <-sc.done:
sc.mu.Unlock()
return true
default:
var err error
for _, ctx := range c.ctxs {
select {
case <-ctx.Done():
err = ctx.Err()
default:
return
}
}
c.err = err
close(c.done)
}
var err error
for _, ctx := range sc.ctxs {
select {
case <-ctx.Done():
err = ctx.Err()
default:
sc.mu.Unlock()
return false
}
}
sc.err = err
close(sc.done)
sc.mu.Unlock()
return true
}
type rawProgressWriter interface {

View File

@ -83,6 +83,61 @@ func TestCancelOne(t *testing.T) {
assert.Equal(t, counter, int64(1))
}
func TestCancelRace(t *testing.T) {
// t.Parallel() // disabled for better timing consistency. works with parallel as well
g := &Group{}
ctx, cancel := context.WithCancel(context.Background())
kick := make(chan struct{})
wait := make(chan struct{})
count := 0
// first run cancels context, second returns cleanly
f := func(ctx context.Context) (interface{}, error) {
done := ctx.Done()
if count > 0 {
time.Sleep(100 * time.Millisecond)
return nil, nil
}
go func() {
for {
select {
case <-wait:
return
default:
ctx.Done()
}
}
}()
count++
time.Sleep(50 * time.Millisecond)
close(kick)
time.Sleep(50 * time.Millisecond)
select {
case <-done:
return nil, ctx.Err()
case <-time.After(200 * time.Millisecond):
}
return nil, nil
}
go func() {
defer close(wait)
<-kick
cancel()
time.Sleep(5 * time.Millisecond)
_, err := g.Do(context.Background(), "foo", f)
require.NoError(t, err)
}()
_, err := g.Do(ctx, "foo", f)
require.Error(t, err)
require.Equal(t, true, errors.Is(err, context.Canceled))
<-wait
}
func TestCancelBoth(t *testing.T) {
t.Parallel()
g := &Group{}
@ -133,7 +188,6 @@ func TestCancelBoth(t *testing.T) {
assert.Equal(t, "", r1)
assert.Equal(t, "", r2)
assert.Equal(t, counter, int64(1))
ret1, err := g.Do(context.TODO(), "foo", f)
assert.NoError(t, err)
assert.Equal(t, ret1, "bar")