diff --git a/solver-next/jobs.go b/solver-next/jobs.go index 165eae50..270bb8f1 100644 --- a/solver-next/jobs.go +++ b/solver-next/jobs.go @@ -14,7 +14,11 @@ import ( ) // ResolveOpFunc finds an Op implementation for a Vertex -type ResolveOpFunc func(Vertex) (Op, error) +type ResolveOpFunc func(Vertex, Builder) (Op, error) + +type Builder interface { + Build(ctx context.Context, e Edge) (CachedResult, error) +} // JobList provides a shared graph of all the vertexes currently being // processed. Every vertex that is being solved needs to be loaded into job @@ -51,6 +55,11 @@ type state struct { cache map[string]CacheManager mainCache CacheManager + jobList *JobList +} + +func (s *state) builder() Builder { + return &subBuilder{state: s} } func (s *state) getEdge(index Index) *edge { @@ -106,6 +115,14 @@ func (s *state) Release() { } } +type subBuilder struct { + *state +} + +func (sb *subBuilder) Build(ctx context.Context, e Edge) (CachedResult, error) { + return sb.jobList.subBuild(ctx, e, sb.vtx) +} + type Job struct { list *JobList pr *progress.MultiReader @@ -157,7 +174,7 @@ func (jl *JobList) GetEdge(e Edge) *edge { return st.getEdge(e.Index) } -func (jl *JobList) SubBuild(ctx context.Context, e Edge, parent Vertex) (CachedResult, error) { +func (jl *JobList) subBuild(ctx context.Context, e Edge, parent Vertex) (CachedResult, error) { v, err := jl.load(e.Vertex, parent, nil) if err != nil { return nil, err @@ -223,6 +240,7 @@ func (jl *JobList) loadUnlocked(v, parent Vertex, j *Job) (Vertex, error) { index: jl.index, mainCache: jl.opts.DefaultCache, cache: map[string]CacheManager{}, + jobList: jl, } jl.actives[dgst] = st } @@ -526,7 +544,7 @@ func (s *sharedOp) Exec(ctx context.Context, inputs []Result) (outputs []Result, func (s *sharedOp) getOp() (Op, error) { s.opOnce.Do(func() { - s.op, s.err = s.resolver(s.st.vtx) + s.op, s.err = s.resolver(s.st.vtx, s.st.builder()) }) if s.err != nil { return nil, s.err diff --git a/solver-next/scheduler_test.go b/solver-next/scheduler_test.go index 4b48ff27..88a9f762 100644 --- a/solver-next/scheduler_test.go +++ b/solver-next/scheduler_test.go @@ -1796,6 +1796,68 @@ func TestParallelBuildsIgnoreCache(t *testing.T) { require.Equal(t, unwrap(res), "result3") } +func TestSubbuild(t *testing.T) { + t.Parallel() + ctx := context.TODO() + + l := NewJobList(SolverOpt{ + ResolveOpFunc: testOpResolver, + }) + defer l.Close() + + j0, err := l.NewJob("j0") + require.NoError(t, err) + + defer func() { + if j0 != nil { + j0.Discard() + } + }() + + g0 := Edge{ + Vertex: vtxSum(1, vtxOpt{ + inputs: []Edge{ + {Vertex: vtxSubBuild(Edge{Vertex: vtxConst(7, vtxOpt{})}, vtxOpt{ + cacheKeySeed: "seed0", + })}, + }, + }), + } + g0.Vertex.(*vertexSum).setupCallCounters() + + res, err := j0.Build(ctx, g0) + require.NoError(t, err) + require.Equal(t, unwrapInt(res), 8) + + require.Equal(t, int64(2), *g0.Vertex.(*vertexSum).cacheCallCount) + require.Equal(t, int64(2), *g0.Vertex.(*vertexSum).execCallCount) + + require.NoError(t, j0.Discard()) + j0 = nil + + j1, err := l.NewJob("j1") + require.NoError(t, err) + + defer func() { + if j1 != nil { + j1.Discard() + } + }() + + g0.Vertex.(*vertexSum).setupCallCounters() + + res, err = j1.Build(ctx, g0) + require.NoError(t, err) + require.Equal(t, unwrapInt(res), 8) + + require.Equal(t, int64(2), *g0.Vertex.(*vertexSum).cacheCallCount) + require.Equal(t, int64(0), *g0.Vertex.(*vertexSum).execCallCount) + + require.NoError(t, j1.Discard()) + j1 = nil + +} + func generateSubGraph(nodes int) (Edge, int) { if nodes == 1 { value := rand.Int() % 500 @@ -1898,6 +1960,8 @@ func (v *vertex) setCallCounters(cacheCount, execCount *int64) { v = vv.vertex case *vertexConst: v = vv.vertex + case *vertexSubBuild: + v = vv.vertex } v.setCallCounters(cacheCount, execCount) } @@ -2040,6 +2104,37 @@ func (v *vertexSum) Exec(ctx context.Context, inputs []Result) (outputs []Result return []Result{&dummyResult{id: identity.NewID(), intValue: s}}, nil } +func vtxSubBuild(g Edge, opt vtxOpt) *vertexSubBuild { + if opt.cacheKeySeed == "" { + opt.cacheKeySeed = fmt.Sprintf("sum-%s", identity.NewID()) + } + if opt.name == "" { + opt.name = opt.cacheKeySeed + "-" + identity.NewID() + } + return &vertexSubBuild{vertex: vtx(opt), g: g} +} + +type vertexSubBuild struct { + *vertex + g Edge + b Builder +} + +func (v *vertexSubBuild) Sys() interface{} { + return v +} + +func (v *vertexSubBuild) Exec(ctx context.Context, inputs []Result) (outputs []Result, err error) { + if err := v.exec(ctx, inputs); err != nil { + return nil, err + } + res, err := v.b.Build(ctx, v.g) + if err != nil { + return nil, err + } + return []Result{res}, nil +} + func printGraph(e Edge, pfx string) { name := e.Vertex.Name() fmt.Printf("%s %d %s\n", pfx, e.Index, name) @@ -2058,10 +2153,14 @@ func (r *dummyResult) ID() string { return r.id } func (r *dummyResult) Release(context.Context) error { return nil } func (r *dummyResult) Sys() interface{} { return r } -func testOpResolver(v Vertex) (Op, error) { +func testOpResolver(v Vertex, b Builder) (Op, error) { if op, ok := v.Sys().(Op); ok { + if vtx, ok := op.(*vertexSubBuild); ok { + vtx.b = b + } return op, nil } + return nil, errors.Errorf("invalid vertex") }