Update SSH agent socket handling to support Windows OpenSSH using named pipes.

Signed-off-by: Siebe Schaap <siebe@digibites.nl>
v0.9
Siebe Schaap 2021-05-20 21:50:35 +02:00
parent e51dbe9fae
commit dbbe65baec
3 changed files with 114 additions and 12 deletions

View File

@ -35,7 +35,11 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {
}
if conf.Paths[0] == "" {
return nil, errors.Errorf("invalid empty ssh agent socket, make sure SSH_AUTH_SOCK is set")
p, err := getFallbackAgentPath()
if err != nil {
return nil, errors.Wrap(err, "invalid empty ssh agent socket")
}
conf.Paths[0] = p
}
src, err := toAgentSource(conf.Paths)
@ -56,7 +60,20 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {
type source struct {
agent agent.Agent
socket string
socket *socketDialer
}
type socketDialer struct {
path string
dialer func(string) (net.Conn, error)
}
func (s socketDialer) Dial() (net.Conn, error) {
return s.dialer(s.path)
}
func (s socketDialer) String() string {
return s.path
}
type socketProvider struct {
@ -94,8 +111,8 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
var a agent.Agent
if src.socket != "" {
conn, err := net.DialTimeout("unix", src.socket, time.Second)
if src.socket != nil {
conn, err := src.socket.Dial()
if err != nil {
return errors.Wrapf(err, "failed to connect to %s", src.socket)
}
@ -124,21 +141,24 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
func toAgentSource(paths []string) (source, error) {
var keys bool
var socket string
var socket *socketDialer
a := agent.NewKeyring()
for _, p := range paths {
if socket != "" {
if socket != nil {
return source{}, errors.New("only single socket allowed")
}
if parsed := parsePlatformSocketPath(p); parsed != nil {
socket = parsed
continue
}
fi, err := os.Stat(p)
if err != nil {
return source{}, errors.WithStack(err)
}
if fi.Mode()&os.ModeSocket > 0 {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
socket = p
socket = &socketDialer{path: p, dialer: unixSocketDialer}
continue
}
@ -160,7 +180,7 @@ func toAgentSource(paths []string) (source, error) {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
socket = p
socket = &socketDialer{path: p, dialer: unixSocketDialer}
continue
}
@ -173,13 +193,20 @@ func toAgentSource(paths []string) (source, error) {
keys = true
}
if socket != "" {
if socket != nil {
if keys {
return source{}, errors.Errorf("invalid combination of keys and sockets")
}
return source{socket: socket}, nil
}
return source{agent: a}, nil
}
func unixSocketDialer(path string) (net.Conn, error) {
return net.DialTimeout("unix", path, 2*time.Second)
}
func sockPair() (io.ReadWriteCloser, io.ReadWriteCloser) {
pr1, pw1 := io.Pipe()
pr2, pw2 := io.Pipe()

View File

@ -0,0 +1,15 @@
// +build !windows
package sshprovider
import (
"github.com/pkg/errors"
)
func getFallbackAgentPath() (string, error) {
return "", errors.Errorf("make sure SSH_AUTH_SOCK is set")
}
func parsePlatformSocketPath(path string) (*socketDialer) {
return nil
}

View File

@ -0,0 +1,60 @@
// +build windows
package sshprovider
import (
"net"
"regexp"
"strings"
"syscall"
"github.com/Microsoft/go-winio"
"github.com/pkg/errors"
)
// Returns the Windows OpenSSH agent named pipe path, but
// only if the agent is running. Returns an error otherwise.
func getFallbackAgentPath() (string, error) {
// Windows OpenSSH agent uses a named pipe rather
// than a UNIX socket. These pipes do not play nice
// with os.Stat (which tries to open its target), so
// use a FindFirstFile syscall to check for existence.
var fd syscall.Win32finddata
path := `\\.\pipe\openssh-ssh-agent`
pathPtr, _ := syscall.UTF16PtrFromString(path)
handle, err := syscall.FindFirstFile(pathPtr, &fd)
if err != nil {
msg := "Windows OpenSSH agent not available at %s." +
" Enable the SSH agent service or set SSH_AUTH_SOCK."
return "", errors.Errorf(msg, path)
}
_ = syscall.CloseHandle(handle)
return path, nil
}
// Returns true if the path references a named pipe.
func isWindowsPipePath(path string) bool {
// If path matches \\*\pipe\* then it references a named pipe
// and requires winio.DialPipe() rather than DialTimeout("unix").
// Slashes and backslashes may be used interchangeably in the path.
// Path separators may consist of multiple consecutive (back)slashes.
pipePattern := strings.ReplaceAll("^[/]{2}[^/]+[/]+pipe[/]+", "[/]", `[\\/]`)
ok, _ := regexp.MatchString(pipePattern, path)
return ok
}
func parsePlatformSocketPath(path string) *socketDialer {
if isWindowsPipePath(path) {
return &socketDialer{path: path, dialer: windowsPipeDialer}
}
return nil
}
func windowsPipeDialer(path string) (net.Conn, error) {
return winio.DialPipe(path, nil)
}