buildkit/util/progress/progress_test.go

126 lines
2.4 KiB
Go

package progress
import (
"context"
"fmt"
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
// TestProgress calls calc once with a bare context.TODO() and once with the
// context returned by NewContext, checking the returned sums and the number
// of progress items that saveProgress collected during the second call.
func TestProgress(t *testing.T) {
	// First run: only the arithmetic result is checked (1+2+3+4).
	sum, err := calc(context.TODO(), 4, "calc")
	assert.NoError(t, err)
	assert.Equal(t, 10, sum)

	group, ctx := errgroup.WithContext(context.Background())
	reader, ctx, stopProgress := NewContext(ctx)

	// Collect everything the reader yields in a background goroutine.
	var recorded trace
	group.Go(func() error {
		return saveProgress(ctx, reader, &recorded)
	})

	// Second run: same helper, now with the NewContext-derived context.
	sum, err = calc(ctx, 5, "calc")
	assert.NoError(t, err)
	assert.Equal(t, 15, sum)

	// Invoke the cancel function returned by NewContext, then wait for the
	// collector goroutine before inspecting what it recorded.
	stopProgress()
	assert.NoError(t, group.Wait())

	count := len(recorded.items)
	assert.True(t, count > 5)
	assert.True(t, count <= 7)
}
// TestProgressNested runs reduceCalc under a NewContext-derived context and
// checks both its result and how many progress items saveProgress collected.
func TestProgressNested(t *testing.T) {
	group, ctx := errgroup.WithContext(context.Background())
	reader, ctx, stopProgress := NewContext(ctx)

	// Collect everything the reader yields in a background goroutine.
	var recorded trace
	group.Go(func() error {
		return saveProgress(ctx, reader, &recorded)
	})

	// reduceCalc returns the sum of its synchronous step only (1+2+3).
	sum, err := reduceCalc(ctx, 3)
	assert.NoError(t, err)
	assert.Equal(t, 6, sum)

	// Invoke the cancel function returned by NewContext, then wait for the
	// collector goroutine before inspecting what it recorded.
	stopProgress()
	assert.NoError(t, group.Wait())

	count := len(recorded.items)
	assert.True(t, count > 9) // usually 14
	assert.True(t, count <= 15)
}
// calc adds the integers 1..total, pausing 10ms before each addition and
// writing a Status to the writer obtained from FromContext(ctx, name) at the
// start and after every pause. If ctx is done during a pause it returns
// 0 and ctx.Err(); otherwise it returns the sum and a nil error.
func calc(ctx context.Context, total int, name string) (int, error) {
	pw, _, ctx := FromContext(ctx, name)
	defer pw.Done()

	pw.Write(Status{Action: "starting", Total: total})

	result := 0
	for step := 1; step <= total; step++ {
		// Wait one tick, bailing out early if the context is cancelled.
		select {
		case <-ctx.Done():
			return 0, ctx.Err()
		case <-time.After(10 * time.Millisecond):
		}

		// The final step reports "done"; every earlier step reports "calculating".
		update := Status{Action: "calculating", Total: total, Current: step}
		if step == total {
			update = Status{Action: "done", Total: total, Current: total}
		}
		pw.Write(update)

		result += step
	}

	// NOTE(review): Done is invoked here and again by the deferred call above;
	// both calls are kept as in the original — confirm whether that is intended.
	pw.Done()
	return result, nil
}
// reduceCalc runs calc once synchronously and then twice more concurrently,
// all under a "reduce" writer obtained from FromContext. It returns the sum
// produced by the synchronous call; the concurrent calls contribute only
// their errors. Any calc error is returned with a zero sum.
func reduceCalc(ctx context.Context, total int) (int, error) {
	group, ctx := errgroup.WithContext(ctx)
	writer, _, ctx := FromContext(ctx, "reduce")
	defer writer.Done()

	writer.Write(Status{Action: "starting"})

	// Synchronous step: its result is what the function ultimately returns.
	result, err := calc(ctx, total, "synccalc")
	if err != nil {
		return 0, err
	}

	// Concurrent steps: two more calc runs whose sums are discarded.
	for i := 0; i < 2; i++ {
		i := i // per-iteration copy for the goroutine below
		group.Go(func() error {
			_, err := calc(ctx, total, fmt.Sprintf("calc-%d", i))
			return err
		})
	}

	if err := group.Wait(); err != nil {
		return 0, err
	}
	return result, nil
}
// trace accumulates the progress items appended to it by saveProgress.
type trace struct{ items []*Progress }
// saveProgress repeatedly calls pr.Read(ctx) and appends every returned item
// to t.items. It returns nil once Read reports io.EOF and returns any other
// Read error unchanged.
func saveProgress(ctx context.Context, pr ProgressReader, t *trace) error {
	for {
		batch, err := pr.Read(ctx)
		switch {
		case err == io.EOF:
			// End of the stream is the normal way out.
			return nil
		case err != nil:
			return err
		}
		t.items = append(t.items, batch...)
	}
}