219 lines
4.3 KiB
Go
219 lines
4.3 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/pkg/errors"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
// Caller can invoke requests on the session
|
|
type Caller interface {
|
|
Context() context.Context
|
|
Supports(method string) bool
|
|
Conn() *grpc.ClientConn
|
|
Name() string
|
|
SharedKey() string
|
|
}
|
|
|
|
type client struct {
|
|
Session
|
|
cc *grpc.ClientConn
|
|
supported map[string]struct{}
|
|
}
|
|
|
|
// Manager is a controller for accessing currently active sessions
|
|
type Manager struct {
|
|
sessions map[string]*client
|
|
mu sync.Mutex
|
|
updateCondition *sync.Cond
|
|
}
|
|
|
|
// NewManager returns a new Manager
|
|
func NewManager() (*Manager, error) {
|
|
sm := &Manager{
|
|
sessions: make(map[string]*client),
|
|
}
|
|
sm.updateCondition = sync.NewCond(&sm.mu)
|
|
return sm, nil
|
|
}
|
|
|
|
// HandleHTTPRequest handles an incoming HTTP request
|
|
func (sm *Manager) HandleHTTPRequest(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
return errors.New("handler does not support hijack")
|
|
}
|
|
|
|
id := r.Header.Get(headerSessionID)
|
|
|
|
proto := r.Header.Get("Upgrade")
|
|
|
|
sm.mu.Lock()
|
|
if _, ok := sm.sessions[id]; ok {
|
|
sm.mu.Unlock()
|
|
return errors.Errorf("session %s already exists", id)
|
|
}
|
|
|
|
if proto == "" {
|
|
sm.mu.Unlock()
|
|
return errors.New("no upgrade proto in request")
|
|
}
|
|
|
|
if proto != "h2c" {
|
|
sm.mu.Unlock()
|
|
return errors.Errorf("protocol %s not supported", proto)
|
|
}
|
|
|
|
conn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
sm.mu.Unlock()
|
|
return errors.Wrap(err, "failed to hijack connection")
|
|
}
|
|
|
|
resp := &http.Response{
|
|
StatusCode: http.StatusSwitchingProtocols,
|
|
ProtoMajor: 1,
|
|
ProtoMinor: 1,
|
|
Header: http.Header{},
|
|
}
|
|
resp.Header.Set("Connection", "Upgrade")
|
|
resp.Header.Set("Upgrade", proto)
|
|
|
|
// set raw mode
|
|
conn.Write([]byte{})
|
|
resp.Write(conn)
|
|
|
|
return sm.handleConn(ctx, conn, r.Header)
|
|
}
|
|
|
|
// HandleConn handles an incoming raw connection
|
|
func (sm *Manager) HandleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error {
|
|
sm.mu.Lock()
|
|
return sm.handleConn(ctx, conn, opts)
|
|
}
|
|
|
|
// caller needs to take lock, this function will release it
|
|
func (sm *Manager) handleConn(ctx context.Context, conn net.Conn, opts map[string][]string) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
opts = canonicalHeaders(opts)
|
|
|
|
h := http.Header(opts)
|
|
id := h.Get(headerSessionID)
|
|
name := h.Get(headerSessionName)
|
|
sharedKey := h.Get(headerSessionSharedKey)
|
|
|
|
ctx, cc, err := grpcClientConn(ctx, conn)
|
|
if err != nil {
|
|
sm.mu.Unlock()
|
|
return err
|
|
}
|
|
|
|
c := &client{
|
|
Session: Session{
|
|
id: id,
|
|
name: name,
|
|
sharedKey: sharedKey,
|
|
ctx: ctx,
|
|
cancelCtx: cancel,
|
|
done: make(chan struct{}),
|
|
},
|
|
cc: cc,
|
|
supported: make(map[string]struct{}),
|
|
}
|
|
|
|
for _, m := range opts[headerSessionMethod] {
|
|
c.supported[strings.ToLower(m)] = struct{}{}
|
|
}
|
|
sm.sessions[id] = c
|
|
sm.updateCondition.Broadcast()
|
|
sm.mu.Unlock()
|
|
|
|
defer func() {
|
|
sm.mu.Lock()
|
|
delete(sm.sessions, id)
|
|
sm.mu.Unlock()
|
|
}()
|
|
|
|
<-c.ctx.Done()
|
|
conn.Close()
|
|
close(c.done)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Get returns a session by ID
|
|
func (sm *Manager) Get(ctx context.Context, id string) (Caller, error) {
|
|
// session prefix is used to identify vertexes with different contexts so
|
|
// they would not collide, but for lookup we don't need the prefix
|
|
if p := strings.SplitN(id, ":", 2); len(p) == 2 && len(p[1]) > 0 {
|
|
id = p[1]
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
sm.updateCondition.Broadcast()
|
|
}
|
|
}()
|
|
|
|
var c *client
|
|
|
|
sm.mu.Lock()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
sm.mu.Unlock()
|
|
return nil, errors.Wrapf(ctx.Err(), "no active session for %s", id)
|
|
default:
|
|
}
|
|
var ok bool
|
|
c, ok = sm.sessions[id]
|
|
if !ok || c.closed() {
|
|
sm.updateCondition.Wait()
|
|
continue
|
|
}
|
|
sm.mu.Unlock()
|
|
break
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (c *client) Context() context.Context {
|
|
return c.context()
|
|
}
|
|
|
|
func (c *client) Name() string {
|
|
return c.name
|
|
}
|
|
|
|
func (c *client) SharedKey() string {
|
|
return c.sharedKey
|
|
}
|
|
|
|
func (c *client) Supports(url string) bool {
|
|
_, ok := c.supported[strings.ToLower(url)]
|
|
return ok
|
|
}
|
|
func (c *client) Conn() *grpc.ClientConn {
|
|
return c.cc
|
|
}
|
|
|
|
// canonicalHeaders returns a new map holding the entries of in with every key
// rewritten to its canonical HTTP header form. The value slices are shared
// with in, not copied, and in itself is left unmodified. The result is never
// nil, even for a nil or empty input.
func canonicalHeaders(in map[string][]string) map[string][]string {
	canonical := make(map[string][]string, len(in))
	for key, values := range in {
		canonical[http.CanonicalHeaderKey(key)] = values
	}
	return canonical
}
|