client: allow setting custom dialer for session endpoint

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
master
Tonis Tiigi 2021-10-17 20:21:28 -07:00
parent cbf808fb09
commit 1c51e87e16
2 changed files with 21 additions and 3 deletions

View File

@ -30,6 +30,7 @@ import (
type Client struct { type Client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
sessionDialer func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error)
} }
type ClientOpt interface{} type ClientOpt interface{}
@ -49,6 +50,7 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
var customTracer bool // allows manually setting disabling tracing even if tracer in context var customTracer bool // allows manually setting disabling tracing even if tracer in context
var tracerProvider trace.TracerProvider var tracerProvider trace.TracerProvider
var tracerDelegate TracerDelegate var tracerDelegate TracerDelegate
var sessionDialer func(context.Context, string, map[string][]string) (net.Conn, error)
for _, o := range opts { for _, o := range opts {
if _, ok := o.(*withFailFast); ok { if _, ok := o.(*withFailFast); ok {
@ -73,6 +75,9 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
if wt, ok := o.(*withTracerDelegate); ok { if wt, ok := o.(*withTracerDelegate); ok {
tracerDelegate = wt tracerDelegate = wt
} }
if sd, ok := o.(*withSessionDialer); ok {
sessionDialer = sd.dialer
}
} }
if !customTracer { if !customTracer {
@ -132,6 +137,7 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error
c := &Client{ c := &Client{
conn: conn, conn: conn,
sessionDialer: sessionDialer,
} }
if tracerDelegate != nil { if tracerDelegate != nil {
@ -244,6 +250,14 @@ type withTracerDelegate struct {
TracerDelegate TracerDelegate
} }
func WithSessionDialer(dialer func(context.Context, string, map[string][]string) (net.Conn, error)) ClientOpt {
return &withSessionDialer{dialer}
}
type withSessionDialer struct {
dialer func(context.Context, string, map[string][]string) (net.Conn, error)
}
func resolveDialer(address string) (func(context.Context, string) (net.Conn, error), error) { func resolveDialer(address string) (func(context.Context, string) (net.Conn, error), error) {
ch, err := connhelper.GetConnectionHelper(address) ch, err := connhelper.GetConnectionHelper(address)
if err != nil { if err != nil {

View File

@ -162,7 +162,11 @@ func (c *Client) solve(ctx context.Context, def *llb.Definition, runGateway runG
} }
eg.Go(func() error { eg.Go(func() error {
return s.Run(statusContext, grpchijack.Dialer(c.controlClient())) sd := c.sessionDialer
if sd == nil {
sd = grpchijack.Dialer(c.controlClient())
}
return s.Run(statusContext, sd)
}) })
} }