session: release forwarded ssh socket connection per connection
Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>v0.7
parent
6885427759
commit
bc3a1eefdd
|
@ -9,17 +9,17 @@ import (
|
|||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func Copy(ctx context.Context, conn io.ReadWriteCloser, stream grpc.Stream) error {
|
||||
func Copy(ctx context.Context, conn io.ReadWriteCloser, stream grpc.Stream, closeStream func() error) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() (retErr error) {
|
||||
p := &BytesMessage{}
|
||||
for {
|
||||
if err := stream.RecvMsg(p); err != nil {
|
||||
conn.Close()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
conn.Close()
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
select {
|
||||
|
@ -42,6 +42,9 @@ func Copy(ctx context.Context, conn io.ReadWriteCloser, stream grpc.Stream) erro
|
|||
n, err := conn.Read(buf)
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
if closeStream != nil {
|
||||
closeStream()
|
||||
}
|
||||
return nil
|
||||
case err != nil:
|
||||
return errors.WithStack(err)
|
||||
|
|
|
@ -49,7 +49,7 @@ func (s *server) run(ctx context.Context, l net.Listener, id string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
go Copy(ctx, conn, stream)
|
||||
go Copy(ctx, conn, stream, stream.CloseSend)
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -114,7 +114,7 @@ func (sp *socketProvider) ForwardAgent(stream sshforward.SSH_ForwardAgentServer)
|
|||
|
||||
eg.Go(func() error {
|
||||
defer s1.Close()
|
||||
return sshforward.Copy(ctx, s2, stream)
|
||||
return sshforward.Copy(ctx, s2, stream, nil)
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
|
|
Loading…
Reference in New Issue