fix race condition on progress that could cause deadlock

main
Martin Guibert 2021-03-16 18:07:22 +01:00
parent ead1d41ce6
commit d8665726b9
2 changed files with 27 additions and 14 deletions

View File

@ -23,7 +23,6 @@ type Progress interface {
} }
type progress struct { type progress struct {
ticChan chan struct{}
endChan chan struct{} endChan chan struct{}
started *atomic.Bool started *atomic.Bool
count *atomic.Uint64 count *atomic.Uint64
@ -31,8 +30,7 @@ type progress struct {
func NewProgress() *progress { func NewProgress() *progress {
return &progress{ return &progress{
make(chan struct{}), nil,
make(chan struct{}),
atomic.NewBool(false), atomic.NewBool(false),
atomic.NewUint64(0), atomic.NewUint64(0),
} }
@ -40,6 +38,8 @@ func NewProgress() *progress {
func (p *progress) Start() { func (p *progress) Start() {
if !p.started.Swap(true) { if !p.started.Swap(true) {
p.count.Store(0)
p.endChan = make(chan struct{})
go p.watch() go p.watch()
go p.render() go p.render()
} }
@ -47,14 +47,14 @@ func (p *progress) Start() {
func (p *progress) Stop() { func (p *progress) Stop() {
if p.started.Swap(false) { if p.started.Swap(false) {
p.endChan <- struct{}{} close(p.endChan)
Printf("\n") Printf("\n")
} }
} }
func (p *progress) Inc() { func (p *progress) Inc() {
if p.started.Load() { if p.started.Load() {
p.ticChan <- struct{}{} p.count.Inc()
} }
} }
@ -82,17 +82,19 @@ func (p *progress) render() {
func (p *progress) watch() { func (p *progress) watch() {
Loop: Loop:
for { for {
lastVal := p.count.Load()
select { select {
case <-p.ticChan:
p.count.Inc()
continue Loop
case <-time.After(progressTimeout): case <-time.After(progressTimeout):
p.started.Store(false) if p.count.Load() != lastVal {
break Loop continue
}
if p.started.Swap(false) {
close(p.endChan)
break Loop
}
case <-p.endChan: case <-p.endChan:
return return
} }
} }
logrus.Debug("Progress did not receive any tic. Stopping...") logrus.Debug("Progress did not receive any tic. Stopping...")
p.endChan <- struct{}{}
} }

View File

@ -7,13 +7,24 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestProgressTimeout(t *testing.T) { func TestProgressTimeoutDoesNotInc(t *testing.T) {
progress := NewProgress() progress := NewProgress()
progress.Start() progress.Start()
time.Sleep(progressTimeout + 1) progress.Stop() // should not hang
progress.Inc() // should not hang progress.Inc() // should not hang or inc
assert.Equal(t, uint64(0), progress.Val())
}
func TestProgressTimeoutDoesNotHang(t *testing.T) {
progress := NewProgress()
progress.Start()
time.Sleep(progressTimeout)
for progress.started.Load() == true {
}
progress.Inc() // should not hang or inc
progress.Stop() // should not hang progress.Stop() // should not hang
assert.Equal(t, uint64(0), progress.Val()) assert.Equal(t, uint64(0), progress.Val())
} }
func TestProgress(t *testing.T) { func TestProgress(t *testing.T) {