Update SSH agent socket handling to support Windows OpenSSH using named pipes.
Signed-off-by: Siebe Schaap <siebe@digibites.nl>v0.9
parent
e51dbe9fae
commit
dbbe65baec
|
@ -35,7 +35,11 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {
|
|||
}
|
||||
|
||||
if conf.Paths[0] == "" {
|
||||
return nil, errors.Errorf("invalid empty ssh agent socket, make sure SSH_AUTH_SOCK is set")
|
||||
p, err := getFallbackAgentPath()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid empty ssh agent socket")
|
||||
}
|
||||
conf.Paths[0] = p
|
||||
}
|
||||
|
||||
src, err := toAgentSource(conf.Paths)
|
||||
|
@ -56,7 +60,20 @@ func NewSSHAgentProvider(confs []AgentConfig) (session.Attachable, error) {
|
|||
|
||||
type source struct {
|
||||
agent agent.Agent
|
||||
socket string
|
||||
socket *socketDialer
|
||||
}
|
||||
|
||||
type socketDialer struct {
|
||||
path string
|
||||
dialer func(string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func (s socketDialer) Dial() (net.Conn, error) {
|
||||
return s.dialer(s.path)
|
||||
}
|
||||
|
||||
func (s socketDialer) String() string {
|
||||
return s.path
|
||||
}
|
||||
|
||||
type socketProvider struct {
|
||||
|
@ -94,8 +111,8 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
|
|||
|
||||
var a agent.Agent
|
||||
|
||||
if src.socket != "" {
|
||||
conn, err := net.DialTimeout("unix", src.socket, time.Second)
|
||||
if src.socket != nil {
|
||||
conn, err := src.socket.Dial()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s", src.socket)
|
||||
}
|
||||
|
@ -124,21 +141,24 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
|
|||
|
||||
func toAgentSource(paths []string) (source, error) {
|
||||
var keys bool
|
||||
var socket string
|
||||
var socket *socketDialer
|
||||
a := agent.NewKeyring()
|
||||
for _, p := range paths {
|
||||
if socket != "" {
|
||||
if socket != nil {
|
||||
return source{}, errors.New("only single socket allowed")
|
||||
}
|
||||
|
||||
if parsed := parsePlatformSocketPath(p); parsed != nil {
|
||||
socket = parsed
|
||||
continue
|
||||
}
|
||||
|
||||
fi, err := os.Stat(p)
|
||||
if err != nil {
|
||||
return source{}, errors.WithStack(err)
|
||||
}
|
||||
if fi.Mode()&os.ModeSocket > 0 {
|
||||
if keys {
|
||||
return source{}, errors.Errorf("invalid combination of keys and sockets")
|
||||
}
|
||||
socket = p
|
||||
socket = &socketDialer{path: p, dialer: unixSocketDialer}
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -160,7 +180,7 @@ func toAgentSource(paths []string) (source, error) {
|
|||
if keys {
|
||||
return source{}, errors.Errorf("invalid combination of keys and sockets")
|
||||
}
|
||||
socket = p
|
||||
socket = &socketDialer{path: p, dialer: unixSocketDialer}
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -173,13 +193,20 @@ func toAgentSource(paths []string) (source, error) {
|
|||
keys = true
|
||||
}
|
||||
|
||||
if socket != "" {
|
||||
if socket != nil {
|
||||
if keys {
|
||||
return source{}, errors.Errorf("invalid combination of keys and sockets")
|
||||
}
|
||||
return source{socket: socket}, nil
|
||||
}
|
||||
|
||||
return source{agent: a}, nil
|
||||
}
|
||||
|
||||
func unixSocketDialer(path string) (net.Conn, error) {
|
||||
return net.DialTimeout("unix", path, 2*time.Second)
|
||||
}
|
||||
|
||||
func sockPair() (io.ReadWriteCloser, io.ReadWriteCloser) {
|
||||
pr1, pw1 := io.Pipe()
|
||||
pr2, pw2 := io.Pipe()
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
// +build !windows
|
||||
|
||||
package sshprovider
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func getFallbackAgentPath() (string, error) {
|
||||
return "", errors.Errorf("make sure SSH_AUTH_SOCK is set")
|
||||
}
|
||||
|
||||
func parsePlatformSocketPath(path string) (*socketDialer) {
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
// +build windows
|
||||
|
||||
package sshprovider
|
||||
|
||||
import (
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Returns the Windows OpenSSH agent named pipe path, but
|
||||
// only if the agent is running. Returns an error otherwise.
|
||||
func getFallbackAgentPath() (string, error) {
|
||||
// Windows OpenSSH agent uses a named pipe rather
|
||||
// than a UNIX socket. These pipes do not play nice
|
||||
// with os.Stat (which tries to open its target), so
|
||||
// use a FindFirstFile syscall to check for existence.
|
||||
var fd syscall.Win32finddata
|
||||
|
||||
path := `\\.\pipe\openssh-ssh-agent`
|
||||
pathPtr, _ := syscall.UTF16PtrFromString(path)
|
||||
handle, err := syscall.FindFirstFile(pathPtr, &fd)
|
||||
|
||||
if err != nil {
|
||||
msg := "Windows OpenSSH agent not available at %s." +
|
||||
" Enable the SSH agent service or set SSH_AUTH_SOCK."
|
||||
return "", errors.Errorf(msg, path)
|
||||
}
|
||||
|
||||
_ = syscall.CloseHandle(handle)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// Returns true if the path references a named pipe.
|
||||
func isWindowsPipePath(path string) bool {
|
||||
// If path matches \\*\pipe\* then it references a named pipe
|
||||
// and requires winio.DialPipe() rather than DialTimeout("unix").
|
||||
// Slashes and backslashes may be used interchangeably in the path.
|
||||
// Path separators may consist of multiple consecutive (back)slashes.
|
||||
pipePattern := strings.ReplaceAll("^[/]{2}[^/]+[/]+pipe[/]+", "[/]", `[\\/]`)
|
||||
ok, _ := regexp.MatchString(pipePattern, path)
|
||||
return ok
|
||||
}
|
||||
|
||||
func parsePlatformSocketPath(path string) *socketDialer {
|
||||
if isWindowsPipePath(path) {
|
||||
return &socketDialer{path: path, dialer: windowsPipeDialer}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func windowsPipeDialer(path string) (net.Conn, error) {
|
||||
return winio.DialPipe(path, nil)
|
||||
}
|
Loading…
Reference in New Issue