driftctl/pkg/output/progress.go

144 lines
2.5 KiB
Go

package output
import (
"fmt"
"time"
"go.uber.org/atomic"
"github.com/sirupsen/logrus"
)
// spinner lists the animation frames, in display order, that render cycles
// through while the progress is running.
var spinner = []string{
	"⣾",
	"⣽",
	"⣻",
	"⢿",
	"⡿",
	"⣟",
	"⣯",
	"⣷",
}
const (
	// progressTimeout is how long watch waits without the counter changing
	// before it stops the progress on its own.
	progressTimeout = 10 * time.Second

	// progressRefreshRate is the delay between two spinner redraws in render.
	progressRefreshRate = 200 * time.Millisecond
)
// Progress is a counter-backed progress indicator that can be started,
// stopped and incremented.
type Progress interface {
	// Start begins a progress run.
	Start()
	// Stop ends the current progress run.
	Stop()
	// Inc adds one to the progress counter.
	Inc()
	// Val reports the current value of the progress counter.
	Val() uint64
}
// ProgressOptions groups the display settings of a progress indicator.
type ProgressOptions struct {
	// LoadingText is the text to show while the progress is running.
	LoadingText string
	// FinishedText is the text to show once the progress has stopped.
	FinishedText string
	// ShowCount tells whether the counter value is displayed.
	ShowCount bool
}
// progress is the package's Progress implementation: it keeps a counter and
// draws a spinner line through printf while it is started.
//
// The field order matters: NewProgress fills the struct positionally.
type progress struct {
	// endChan is created by Start and closed by Stop or watch to end the
	// background goroutines of the current run.
	endChan chan struct{}
	// started is true between a Start and the matching stop.
	started *atomic.Bool
	// count is the number of increments recorded so far.
	count *atomic.Uint64
	// loadingText is printed while the progress is running.
	loadingText string
	// finishedText is printed by Stop.
	finishedText string
	// showCount adds the counter value to the printed lines.
	showCount bool
	// highestLineLength is the byte length of the longest line printed so
	// far; flush uses it to blank the previous output.
	highestLineLength int
}
// NewProgress builds a stopped progress with a zero counter.
//
// loadingText is shown while running, finishedText is shown by Stop, and
// showCount controls whether the counter is appended to both.
func NewProgress(loadingText, finishedText string, showCount bool) *progress {
	// endChan and highestLineLength are left at their zero values (nil and 0);
	// Start allocates the channel.
	p := &progress{
		loadingText:  loadingText,
		finishedText: finishedText,
		showCount:    showCount,
	}
	p.started = atomic.NewBool(false)
	p.count = atomic.NewUint64(0)
	return p
}
// Start begins a new run: it resets the counter, creates a fresh end channel
// and launches the watch and render goroutines. Calling it while already
// started does nothing.
func (p *progress) Start() {
	alreadyRunning := p.started.Swap(true)
	if alreadyRunning {
		return
	}

	p.count.Store(0)
	p.endChan = make(chan struct{})

	go p.watch()
	go p.render()
}
// Stop ends the current run: it prints the finished text (followed by the
// counter when showCount is set) and closes the end channel so the background
// goroutines return. Calling it while not started does nothing.
func (p *progress) Stop() {
	wasRunning := p.started.Swap(false)
	if !wasRunning {
		return
	}

	switch {
	case p.showCount:
		p.printf("%s (%d)\n", p.finishedText, p.count.Load())
	default:
		p.printf("%s\r", p.finishedText)
	}

	close(p.endChan)
}
// Inc adds one to the counter. If the progress is not started, it is started
// again first and the counter value read before the restart is put back, since
// Start resets the counter to zero.
func (p *progress) Inc() {
	// Read the counter before checking the started flag, as Start would wipe it.
	previous := p.count.Load()

	if running := p.started.Load(); !running {
		logrus.Debug("Progress received a tic after stopping. Restarting...")
		p.Start()
		p.count.Store(previous)
	}

	p.count.Inc()
}
// Val returns the current value of the counter.
func (p *progress) Val() uint64 {
	current := p.count.Load()
	return current
}
// render prints the loading text once, then redraws it with the next spinner
// frame (and the counter when showCount is set) every progressRefreshRate. It
// returns when the end channel of its run is closed.
func (p *progress) render() {
	// Capture this run's channel once. Start installs a new channel on every
	// restart, so re-reading p.endChan on each iteration could let a renderer
	// from a previous run miss its own close signal, latch onto the new
	// channel and keep drawing alongside the new renderer.
	done := p.endChan

	p.printf("%s\r", p.loadingText)

	// A single ticker replaces allocating a new timer with time.After on every
	// iteration; it is released when the goroutine returns.
	ticker := time.NewTicker(progressRefreshRate)
	defer ticker.Stop()

	// frame advances only after a tick has been drawn, so the first frame
	// shown is spinner[0] and the index wraps at the end of the slice.
	for frame := 0; ; frame = (frame + 1) % len(spinner) {
		select {
		case <-done:
			return
		case <-ticker.C:
			if p.showCount {
				p.printf("%s %s (%d)\r", p.loadingText, spinner[frame], p.count.Load())
			} else {
				p.printf("%s %s\r", p.loadingText, spinner[frame])
			}
		}
	}
}
// watch stops the progress when the counter has not changed for
// progressTimeout. It returns as soon as the end channel is closed, or right
// after closing that channel itself on a timeout.
func (p *progress) watch() {
	for {
		snapshot := p.count.Load()

		select {
		case <-time.After(progressTimeout):
			// The counter moved during the wait: take a new snapshot and wait again.
			if p.count.Load() != snapshot {
				continue
			}
			// Only the goroutine that flips started to false closes the
			// channel; otherwise loop again and leave through the closed
			// channel below.
			if p.started.Swap(false) {
				close(p.endChan)
				logrus.Debug("Progress did not receive any tic. Stopping...")
				return
			}
		case <-p.endChan:
			return
		}
	}
}
// flush blanks the current line by printing highestLineLength spaces and then
// a carriage return.
func (p *progress) flush() {
	for remaining := p.highestLineLength; remaining > 0; remaining-- {
		Printf(" ")
	}
	Printf("\r")
}
// printf formats a line, records its byte length if it is the longest seen so
// far, blanks the current line through flush and then prints the new text.
func (p *progress) printf(format string, args ...interface{}) {
	txt := fmt.Sprintf(format, args...)
	if length := len(txt); length > p.highestLineLength {
		p.highestLineLength = length
	}
	p.flush()
	// txt is already fully formatted. Passing it as the format string would
	// re-interpret any '%' it contains (for example from the loading or
	// finished text) and garble the output, so print it through an explicit
	// "%s" verb.
	// NOTE(review): assumes the package-level Printf follows fmt.Printf
	// formatting semantics — confirm against its definition.
	Printf("%s", txt)
}