diff --git a/client/client.go b/client/client.go index a2b54fc5..0da2a12b 100644 --- a/client/client.go +++ b/client/client.go @@ -29,7 +29,8 @@ import ( ) type Client struct { - conn *grpc.ClientConn + conn *grpc.ClientConn + sessionDialer func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) } type ClientOpt interface{} @@ -49,6 +50,7 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error var customTracer bool // allows manually setting disabling tracing even if tracer in context var tracerProvider trace.TracerProvider var tracerDelegate TracerDelegate + var sessionDialer func(context.Context, string, map[string][]string) (net.Conn, error) for _, o := range opts { if _, ok := o.(*withFailFast); ok { @@ -73,6 +75,9 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error if wt, ok := o.(*withTracerDelegate); ok { tracerDelegate = wt } + if sd, ok := o.(*withSessionDialer); ok { + sessionDialer = sd.dialer + } } if !customTracer { @@ -131,7 +136,8 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error } c := &Client{ - conn: conn, + conn: conn, + sessionDialer: sessionDialer, } if tracerDelegate != nil { @@ -244,6 +250,14 @@ type withTracerDelegate struct { TracerDelegate } +func WithSessionDialer(dialer func(context.Context, string, map[string][]string) (net.Conn, error)) ClientOpt { + return &withSessionDialer{dialer} +} + +type withSessionDialer struct { + dialer func(context.Context, string, map[string][]string) (net.Conn, error) +} + func resolveDialer(address string) (func(context.Context, string) (net.Conn, error), error) { ch, err := connhelper.GetConnectionHelper(address) if err != nil { diff --git a/client/solve.go b/client/solve.go index 3b765ef9..f53fe842 100644 --- a/client/solve.go +++ b/client/solve.go @@ -162,7 +162,11 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG } eg.Go(func() error { - return s.Run(statusContext, grpchijack.Dialer(c.controlClient())) + sd := c.sessionDialer + if sd == nil { + sd = grpchijack.Dialer(c.controlClient()) + } + return s.Run(statusContext, sd) }) }