session: release forwarded ssh socket connection per connection

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
v0.7
Tonis Tiigi 2019-08-26 15:34:55 -07:00
parent 6885427759
commit bc3a1eefdd
3 changed files with 7 additions and 4 deletions

View File

@ -9,17 +9,17 @@ import (
"google.golang.org/grpc"
)
func Copy(ctx context.Context, conn io.ReadWriteCloser, stream grpc.Stream) error {
func Copy(ctx context.Context, conn io.ReadWriteCloser, stream grpc.Stream, closeStream func() error) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() (retErr error) {
p := &BytesMessage{}
for {
if err := stream.RecvMsg(p); err != nil {
conn.Close()
if err == io.EOF {
return nil
}
conn.Close()
return errors.WithStack(err)
}
select {
@ -42,6 +42,9 @@ func Copy(ctx context.Context, conn io.ReadWriteCloser, stream grpc.Stream) erro
n, err := conn.Read(buf)
switch {
case err == io.EOF:
if closeStream != nil {
closeStream()
}
return nil
case err != nil:
return errors.WithStack(err)

View File

@ -49,7 +49,7 @@ func (s *server) run(ctx context.Context, l net.Listener, id string) error {
return err
}
go Copy(ctx, conn, stream)
go Copy(ctx, conn, stream, stream.CloseSend)
}
})

View File

@ -114,7 +114,7 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
eg.Go(func() error {
defer s1.Close()
return sshforward.Copy(ctx, s2, stream)
return sshforward.Copy(ctx, s2, stream, nil)
})
return eg.Wait()