diff --git a/cmd/buildkitd/main.go b/cmd/buildkitd/main.go index 0eff5e48..48d1eeda 100644 --- a/cmd/buildkitd/main.go +++ b/cmd/buildkitd/main.go @@ -8,8 +8,10 @@ import ( "io/ioutil" "net" "os" + "os/user" "path/filepath" "sort" + "strconv" "strings" "github.com/containerd/containerd/sys" @@ -80,6 +82,11 @@ func main() { Usage: "listening address (socket or tcp)", Value: &cli.StringSlice{appdefaults.Address}, }, + cli.StringFlag{ + Name: "group", + Usage: "group (name or gid) which will own all Unix socket listening addresses", + Value: "", + }, cli.StringFlag{ Name: "debugaddr", Usage: "debugging address (eg. 0.0.0.0:6060)", @@ -142,7 +149,7 @@ func main() { if len(addrs) > 1 { addrs = addrs[1:] // https://github.com/urfave/cli/issues/160 } - if err := serveGRPC(server, addrs, errCh); err != nil { + if err := serveGRPC(c, server, addrs, errCh); err != nil { return err } @@ -181,14 +188,14 @@ func main() { } } -func serveGRPC(server *grpc.Server, addrs []string, errCh chan error) error { +func serveGRPC(c *cli.Context, server *grpc.Server, addrs []string, errCh chan error) error { if len(addrs) == 0 { return errors.New("--addr cannot be empty") } eg, _ := errgroup.WithContext(context.Background()) listeners := make([]net.Listener, 0, len(addrs)) for _, addr := range addrs { - l, err := getListener(addr) + l, err := getListener(c, addr) if err != nil { for _, l := range listeners { l.Close() @@ -212,7 +219,40 @@ func serveGRPC(server *grpc.Server, addrs []string, errCh chan error) error { return nil } -func getListener(addr string) (net.Listener, error) { +// Convert a string containing either a group name or a stringified gid into a numeric id) +func groupToGid(group string) (int, error) { + if group == "" { + return os.Getgid(), nil + } + + var ( + err error + id int + ) + + // Try and parse as a number, if the error is ErrSyntax + // (i.e. its not a number) then we carry on and try it as a + // name. + if id, err = strconv.Atoi(group); err == nil { + return id, nil + } else if err.(*strconv.NumError).Err != strconv.ErrSyntax { + return 0, err + } + + ginfo, err := user.LookupGroup(group) + if err != nil { + return 0, err + } + group = ginfo.Gid + + if id, err = strconv.Atoi(group); err != nil { + return 0, err + } + + return id, nil +} + +func getListener(c *cli.Context, addr string) (net.Listener, error) { addrSlice := strings.SplitN(addr, "://", 2) if len(addrSlice) < 2 { return nil, errors.Errorf("address %s does not contain proto, you meant unix://%s ?", @@ -222,7 +262,12 @@ func getListener(addr string) (net.Listener, error) { listenAddr := addrSlice[1] switch proto { case "unix", "npipe": - return sys.GetLocalListener(listenAddr, os.Getuid(), os.Getgid()) + uid := os.Getuid() + gid, err := groupToGid(c.GlobalString("group")) + if err != nil { + return nil, err + } + return sys.GetLocalListener(listenAddr, uid, gid) case "tcp": return sockets.NewTCPSocket(listenAddr, nil) default: