buildkit/util/progress/progress_test.go

133 lines
2.6 KiB
Go

package progress
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
// TestProgress runs calc twice: first with a bare context.TODO(), checking
// only the returned sum, and then with a context obtained from NewContext
// while a goroutine drains the matching reader into a trace. After the
// second run it checks how many items were recorded and that the last one
// has Done set.
func TestProgress(t *testing.T) {
	s, err := calc(context.TODO(), 4, "calc")
	assert.NoError(t, err)
	assert.Equal(t, 10, s) // 1+2+3+4

	eg, ctx := errgroup.WithContext(context.Background())

	pr, ctx, cancelProgress := NewContext(ctx)
	var trace trace
	eg.Go(func() error {
		return saveProgress(ctx, pr, &trace)
	})

	s, err = calc(ctx, 5, "calc")
	assert.NoError(t, err)
	assert.Equal(t, 15, s) // 1+2+3+4+5

	// The trace is only read after Wait returns, i.e. after the goroutine
	// appending to it has finished.
	cancelProgress()
	err = eg.Wait()
	assert.NoError(t, err)

	assert.True(t, len(trace.items) > 5)
	assert.True(t, len(trace.items) <= 7)

	// assert.True above reports a failure but does not stop the test, so the
	// last element must not be indexed when nothing was recorded; otherwise
	// an empty trace would panic instead of failing cleanly.
	if n := len(trace.items); n > 0 {
		assert.True(t, trace.items[n-1].Done)
	}
}
// TestProgressNested runs reduceCalc under a context obtained from
// NewContext while a goroutine drains the matching reader into a trace. It
// then checks the number of recorded items and that exactly four of them
// have Done set, matching the four names reduceCalc hands to FromContext
// ("reduce", "synccalc", "calc-0" and "calc-1").
func TestProgressNested(t *testing.T) {
	eg, ctx := errgroup.WithContext(context.Background())

	pr, ctx, cancelProgress := NewContext(ctx)
	var trace trace
	eg.Go(func() error {
		return saveProgress(ctx, pr, &trace)
	})

	s, err := reduceCalc(ctx, 3)
	assert.NoError(t, err)
	assert.Equal(t, 6, s) // 1+2+3 from the synchronous step

	// The trace is only read after Wait returns, i.e. after the goroutine
	// appending to it has finished.
	cancelProgress()
	err = eg.Wait()
	assert.NoError(t, err)

	assert.True(t, len(trace.items) > 9) // usually 14
	assert.True(t, len(trace.items) <= 15)

	// Count the items flagged as Done. The loop variable is deliberately not
	// called t so that it does not shadow the *testing.T.
	streams := 0
	for _, item := range trace.items {
		if item.Done {
			streams++
		}
	}
	// Expected value first, so a failure message labels the numbers correctly.
	assert.Equal(t, 4, streams)
}
// calc adds up the integers 1..total and returns the sum. It obtains a
// writer from FromContext under the given name, writes a "starting" status,
// then waits 10ms before each step and writes a status after it:
// "calculating" for intermediate steps and "done" for the last one.
//
// If ctx is done while waiting for a step, calc returns 0 and ctx.Err().
func calc(ctx context.Context, total int, name string) (int, error) {
	writer, _, ctx := FromContext(ctx, name)
	defer writer.Done()

	writer.Write(Status{Action: "starting", Total: total})

	acc := 0
	for step := 1; step <= total; step++ {
		select {
		case <-time.After(10 * time.Millisecond):
		case <-ctx.Done():
			return 0, ctx.Err()
		}

		update := Status{Action: "calculating", Total: total, Current: step}
		if step == total {
			update = Status{Action: "done", Total: total, Current: total}
		}
		writer.Write(update)

		acc += step
	}

	// Done is invoked here and once more by the deferred call above.
	writer.Done()
	return acc, nil
}
// reduceCalc obtains a writer named "reduce", writes a "starting" status,
// and then runs calc once synchronously ("synccalc") followed by two calc
// calls in parallel ("calc-0" and "calc-1") through an errgroup.
//
// It returns the sum produced by the synchronous call; the results of the
// parallel calls are discarded and only their errors are observed. Any error
// yields a zero sum.
func reduceCalc(ctx context.Context, total int) (int, error) {
	group, ctx := errgroup.WithContext(ctx)

	writer, _, ctx := FromContext(ctx, "reduce")
	defer writer.Done()

	writer.Write(Status{Action: "starting"})

	// Synchronous step: its result is the function's return value.
	sum, err := calc(ctx, total, "synccalc")
	if err != nil {
		return 0, err
	}

	// Parallel steps: each goroutine gets its own name variable, declared
	// inside the loop body so the closure does not share it across iterations.
	for i := 0; i < 2; i++ {
		name := fmt.Sprintf("calc-%d", i)
		group.Go(func() error {
			_, calcErr := calc(ctx, total, name)
			return calcErr
		})
	}

	err = group.Wait()
	if err != nil {
		return 0, err
	}
	return sum, nil
}
// trace collects the Progress values that saveProgress reads, so that the
// tests can inspect them once reading has finished.
type trace struct {
	// items holds the recorded values in the order saveProgress appended them.
	items []Progress
}
// saveProgress repeatedly calls pr.Read and appends a copy of every value it
// yields to t.items. It stops and returns the error as soon as Read reports
// one, and stops with a nil error when Read yields a nil value.
func saveProgress(ctx context.Context, pr ProgressReader, t *trace) error {
	for {
		item, err := pr.Read(ctx)
		switch {
		case err != nil:
			return err
		case item == nil:
			return nil
		}
		t.items = append(t.items, *item)
	}
}