fix race condition on progress that could cause deadlock
parent
ead1d41ce6
commit
d8665726b9
|
@ -23,7 +23,6 @@ type Progress interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type progress struct {
|
type progress struct {
|
||||||
ticChan chan struct{}
|
|
||||||
endChan chan struct{}
|
endChan chan struct{}
|
||||||
started *atomic.Bool
|
started *atomic.Bool
|
||||||
count *atomic.Uint64
|
count *atomic.Uint64
|
||||||
|
@ -31,8 +30,7 @@ type progress struct {
|
||||||
|
|
||||||
func NewProgress() *progress {
|
func NewProgress() *progress {
|
||||||
return &progress{
|
return &progress{
|
||||||
make(chan struct{}),
|
nil,
|
||||||
make(chan struct{}),
|
|
||||||
atomic.NewBool(false),
|
atomic.NewBool(false),
|
||||||
atomic.NewUint64(0),
|
atomic.NewUint64(0),
|
||||||
}
|
}
|
||||||
|
@ -40,6 +38,8 @@ func NewProgress() *progress {
|
||||||
|
|
||||||
func (p *progress) Start() {
|
func (p *progress) Start() {
|
||||||
if !p.started.Swap(true) {
|
if !p.started.Swap(true) {
|
||||||
|
p.count.Store(0)
|
||||||
|
p.endChan = make(chan struct{})
|
||||||
go p.watch()
|
go p.watch()
|
||||||
go p.render()
|
go p.render()
|
||||||
}
|
}
|
||||||
|
@ -47,14 +47,14 @@ func (p *progress) Start() {
|
||||||
|
|
||||||
func (p *progress) Stop() {
|
func (p *progress) Stop() {
|
||||||
if p.started.Swap(false) {
|
if p.started.Swap(false) {
|
||||||
p.endChan <- struct{}{}
|
close(p.endChan)
|
||||||
Printf("\n")
|
Printf("\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *progress) Inc() {
|
func (p *progress) Inc() {
|
||||||
if p.started.Load() {
|
if p.started.Load() {
|
||||||
p.ticChan <- struct{}{}
|
p.count.Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,17 +82,19 @@ func (p *progress) render() {
|
||||||
func (p *progress) watch() {
|
func (p *progress) watch() {
|
||||||
Loop:
|
Loop:
|
||||||
for {
|
for {
|
||||||
|
lastVal := p.count.Load()
|
||||||
select {
|
select {
|
||||||
case <-p.ticChan:
|
|
||||||
p.count.Inc()
|
|
||||||
continue Loop
|
|
||||||
case <-time.After(progressTimeout):
|
case <-time.After(progressTimeout):
|
||||||
p.started.Store(false)
|
if p.count.Load() != lastVal {
|
||||||
break Loop
|
continue
|
||||||
|
}
|
||||||
|
if p.started.Swap(false) {
|
||||||
|
close(p.endChan)
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
case <-p.endChan:
|
case <-p.endChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logrus.Debug("Progress did not receive any tic. Stopping...")
|
logrus.Debug("Progress did not receive any tic. Stopping...")
|
||||||
p.endChan <- struct{}{}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,13 +7,24 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProgressTimeout(t *testing.T) {
|
func TestProgressTimeoutDoesNotInc(t *testing.T) {
|
||||||
progress := NewProgress()
|
progress := NewProgress()
|
||||||
progress.Start()
|
progress.Start()
|
||||||
time.Sleep(progressTimeout + 1)
|
progress.Stop() // should not hang
|
||||||
progress.Inc() // should not hang
|
progress.Inc() // should not hang or inc
|
||||||
|
assert.Equal(t, uint64(0), progress.Val())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProgressTimeoutDoesNotHang(t *testing.T) {
|
||||||
|
progress := NewProgress()
|
||||||
|
progress.Start()
|
||||||
|
time.Sleep(progressTimeout)
|
||||||
|
for progress.started.Load() == true {
|
||||||
|
}
|
||||||
|
progress.Inc() // should not hang or inc
|
||||||
progress.Stop() // should not hang
|
progress.Stop() // should not hang
|
||||||
assert.Equal(t, uint64(0), progress.Val())
|
assert.Equal(t, uint64(0), progress.Val())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProgress(t *testing.T) {
|
func TestProgress(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue