diff --git a/pkg/output/progress.go b/pkg/output/progress.go index 97ed60f9..2893cacd 100644 --- a/pkg/output/progress.go +++ b/pkg/output/progress.go @@ -23,7 +23,6 @@ type Progress interface { } type progress struct { - ticChan chan struct{} endChan chan struct{} started *atomic.Bool count *atomic.Uint64 @@ -31,8 +30,7 @@ type progress struct { func NewProgress() *progress { return &progress{ - make(chan struct{}), - make(chan struct{}), + nil, atomic.NewBool(false), atomic.NewUint64(0), } @@ -40,6 +38,8 @@ func NewProgress() *progress { func (p *progress) Start() { if !p.started.Swap(true) { + p.count.Store(0) + p.endChan = make(chan struct{}) go p.watch() go p.render() } @@ -47,14 +47,14 @@ func (p *progress) Start() { func (p *progress) Stop() { if p.started.Swap(false) { - p.endChan <- struct{}{} + close(p.endChan) Printf("\n") } } func (p *progress) Inc() { if p.started.Load() { - p.ticChan <- struct{}{} + p.count.Inc() } } @@ -82,17 +82,19 @@ func (p *progress) render() { func (p *progress) watch() { Loop: for { + lastVal := p.count.Load() select { - case <-p.ticChan: - p.count.Inc() - continue Loop case <-time.After(progressTimeout): - p.started.Store(false) - break Loop + if p.count.Load() != lastVal { + continue + } + if p.started.Swap(false) { + close(p.endChan) + break Loop + } case <-p.endChan: return } } logrus.Debug("Progress did not receive any tic. Stopping...") - p.endChan <- struct{}{} } diff --git a/pkg/output/progress_test.go b/pkg/output/progress_test.go index 4b6073f7..fd721d8f 100644 --- a/pkg/output/progress_test.go +++ b/pkg/output/progress_test.go @@ -7,13 +7,24 @@ import ( "github.com/stretchr/testify/assert" ) -func TestProgressTimeout(t *testing.T) { +func TestProgressTimeoutDoesNotInc(t *testing.T) { progress := NewProgress() progress.Start() - time.Sleep(progressTimeout + 1) - progress.Inc() // should not hang + progress.Stop() // should not hang + progress.Inc() // should not hang or inc + assert.Equal(t, uint64(0), progress.Val()) +} + +func TestProgressTimeoutDoesNotHang(t *testing.T) { + progress := NewProgress() + progress.Start() + time.Sleep(progressTimeout) + for progress.started.Load() == true { + } + progress.Inc() // should not hang or inc progress.Stop() // should not hang assert.Equal(t, uint64(0), progress.Val()) + } func TestProgress(t *testing.T) {