From 00e905d394a85a8fd8ab62a4d941482e443a6002 Mon Sep 17 00:00:00 2001 From: Emir Aganovic Date: Wed, 21 Dec 2022 19:57:56 +0100 Subject: [PATCH] Adding Websocket support --- README.md | 34 ++-- go.mod | 5 +- go.sum | 7 + server_integration_test.go | 65 ++++++++ transport/layer.go | 27 +++- transport/transport.go | 1 + transport/ws.go | 321 +++++++++++++++++++++++++++++++++++++ 7 files changed, 447 insertions(+), 13 deletions(-) create mode 100644 server_integration_test.go create mode 100644 transport/ws.go diff --git a/README.md b/README.md index c5ebc33..92fea13 100644 --- a/README.md +++ b/README.md @@ -25,24 +25,29 @@ Lib allows you to write easily client or server or to build up stateful proxies, Writing in GO we are not limited to handle SIP requests/responses in many ways, or to integrate and scale with any external services (databases, caches...). -### UAS build +### UAS/UAC build ```go ua, _ := sipgo.NewUA() // Build user agent srv, _ := sipgo.NewServer(ua) // Creating server handle -srv.OnRegister(registerHandler) +client, _ := sipgo.NewClient(ua) // Creating client handle srv.OnInvite(inviteHandler) srv.OnAck(ackHandler) srv.OnCancel(cancelHandler) srv.OnBye(byeHandler) +// For registrars +// srv.OnRegister(registerHandler) + + // Add listeners srv.Listen("udp", "127.0.0.1:5060") srv.Listen("tcp", "127.0.0.1:5061") ... -// Start serving +// fire server srv.Serve() ``` + ### Server Transaction @@ -69,7 +74,7 @@ srv.OnInvite(func(req *sip.Request, tx sip.ServerTransaction) { ``` -### Stateless response +### Server stateless response ```go srv := sipgo.NewServer() @@ -82,15 +87,13 @@ srv.OnACK(ackHandler) ``` -### UAC build -```go -ua, _ := sipgo.NewUA() // Build user agent -client, _ := sipgo.NewClient(ua) // Creating client handle -``` - ### Client Transaction +**NOTE**: UA needs server handle and listener on same network before sending request + + ```go +client, _ := sipgo.NewClient(ua) // Creating client handle // Request is either from server request handler or created req.SetDestination("10.1.2.3") // Change sip.Request destination @@ -109,6 +112,15 @@ select { ``` +### Client stateless request + +```go +client, _ := sipgo.NewClient(ua) // Creating client handle +req := sip.NewRequest(method, &recipment, "SIP/2.0") +// Send request and forget +client.WriteRequest(req) +``` + ## Proxy build Proxy is combination client and server handle. @@ -147,7 +159,7 @@ More on documentation you can find on [Go doc](https://pkg.go.dev/github.com/emi - [x] UDP - [x] TCP - [ ] TLS -- [ ] WS +- [x] WS - [ ] WSS diff --git a/go.mod b/go.mod index ee8f1f2..b007bb0 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ module github.com/emiraganov/sipgo -go 1.18 +go 1.19 require ( + github.com/gobwas/ws v1.1.0 github.com/prometheus/client_golang v1.12.0 github.com/rs/zerolog v1.26.1 github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b @@ -14,6 +15,8 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect diff --git a/go.sum b/go.sum index e2cb330..facf86c 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,12 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.1.0 h1:7RFti/xnNkMJnrK7D1yQ/iCIB5OrrY/54/H930kIbHA= +github.com/gobwas/ws v1.1.0/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -327,6 +333,7 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/server_integration_test.go b/server_integration_test.go new file mode 100644 index 0000000..7401ad7 --- /dev/null +++ b/server_integration_test.go @@ -0,0 +1,65 @@ +package sipgo + +import ( + "testing" + "time" + + "github.com/emiraganov/sipgo/sip" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWebsocket(t *testing.T) { + ua, _ := NewUA() + // transport.SIPDebug = true + log.Logger = log.Level(zerolog.DebugLevel) + + // Build UAS + srv, err := NewServer(ua) + if err != nil { + log.Fatal().Err(err).Msg("Fail to setup dialog server") + } + srv.OnInvite(func(req *sip.Request, tx sip.ServerTransaction) { + t.Log("Invite received") + res := sip.NewResponseFromRequest(req, 200, "OK", nil) + if err := tx.Respond(res); err != nil { + t.Fatal(err) + } + <-tx.Done() + }) + + srv.Listen("ws", "127.0.0.1:5060") + + go func() { + if err := srv.Serve(); err != nil { + log.Error().Err(err).Msg("Fail to serve") + } + }() + + // Build UAC + ua, _ = NewUA() + client, err := NewClient(ua) + require.Nil(t, err) + + csrv, err := NewServer(ua) // Create server handle + require.Nil(t, err) + csrv.Listen("ws", "127.0.0.2:5060") + go func() { + if err := csrv.Serve(); err != nil { + log.Error().Err(err).Msg("Fail to serve") + } + }() + + time.Sleep(2 * time.Second) + + req, _, _ := createTestInvite(t, "WS", client.ip.String()) + // err = client.WriteRequest(req) + // require.Nil(t, err) + // time.Sleep(2 * time.Second) + tx, err := client.TransactionRequest(req) + require.Nil(t, err) + res := <-tx.Responses() + assert.Equal(t, sip.StatusCode(200), res.StatusCode()) +} diff --git a/transport/layer.go b/transport/layer.go index 8f45706..b334916 100644 --- a/transport/layer.go +++ b/transport/layer.go @@ -2,6 +2,7 @@ package transport import ( "context" + "errors" "fmt" "math/rand" "net" @@ -17,6 +18,10 @@ import ( "github.com/rs/zerolog/log" ) +var ( + ErrNetworkExists = errors.New("network is already served") +) + func init() { rand.Seed(time.Now().UnixNano()) } @@ -97,7 +102,7 @@ func (l *Layer) ServeUDP(c net.PacketConn) error { return transport.ServeConn(c, l.handleMessage) } -// ServeTCP will listen on udp connection +// ServeTCP will listen on tcp connection func (l *Layer) ServeTCP(c net.Listener) error { _, port, err := sip.ParseAddr(c.Addr().String()) if err != nil { @@ -110,6 +115,19 @@ func (l *Layer) ServeTCP(c net.Listener) error { return transport.ServeConn(c, l.handleMessage) } +// ServeWS will listen on ws connection +func (l *Layer) ServeWS(c net.Listener) error { + _, port, err := sip.ParseAddr(c.Addr().String()) + if err != nil { + return err + } + + transport := NewWSTransport(c.Addr().String(), parser.NewParser()) + l.addTransport(transport, port) + + return transport.ServeConn(c, l.handleMessage) +} + // Serve on any network. This function will block func (l *Layer) Serve(ctx context.Context, network string, addr string) error { network = strings.ToLower(network) @@ -118,6 +136,11 @@ func (l *Layer) Serve(ctx context.Context, network string, addr string) error { return err } + _, exists := l.transports[network] + if exists { + return ErrNetworkExists + } + p := parser.NewParser() var t Transport @@ -126,6 +149,8 @@ func (l *Layer) Serve(ctx context.Context, network string, addr string) error { t = NewUDPTransport(addr, p) case "tcp": t = NewTCPTransport(addr, p) + case "ws": + t = NewWSTransport(addr, p) case "tls": fallthrough default: diff --git a/transport/transport.go b/transport/transport.go index 47de8d9..1159911 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -16,6 +16,7 @@ const ( TransportUDP = "UDP" TransportTCP = "TCP" TransportTLS = "TLS" + TransportWS = "WS" ) // Protocol implements network specific features. diff --git a/transport/ws.go b/transport/ws.go new file mode 100644 index 0000000..31b6ce3 --- /dev/null +++ b/transport/ws.go @@ -0,0 +1,321 @@ +package transport + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + + "github.com/emiraganov/sipgo/parser" + "github.com/emiraganov/sipgo/sip" + "github.com/gobwas/ws" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +var () + +// WS transport implementation +type WSTransport struct { + addr string + listener net.Listener + parser parser.SIPParser + handler sip.MessageHandler + log zerolog.Logger + + pool ConnectionPool +} + +func NewWSTransport(addr string, par parser.SIPParser) *WSTransport { + p := &WSTransport{ + addr: addr, + parser: par, + pool: NewConnectionPool(), + } + p.log = log.Logger.With().Str("caller", "transport").Logger() + return p +} + +func (t *WSTransport) String() string { + return "transport" +} + +func (t *WSTransport) Addr() string { + return t.addr +} + +func (t *WSTransport) Network() string { + return "ws" +} + +func (t *WSTransport) Close() error { + // return t.connections.Done() + var err error + if t.listener == nil { + return nil + } + + if err := t.listener.Close(); err != nil { + err = fmt.Errorf("err=%w", err) + } + + t.listener = nil + return err +} + +// This is more generic way to provide listener and it is blocking +func (t *WSTransport) Serve(handler sip.MessageHandler) error { + addr := t.addr + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return fmt.Errorf("fail to resolve address. err=%w", err) + } + + conn, err := net.ListenTCP("tcp", laddr) + if err != nil { + return fmt.Errorf("listen tcp error. err=%w", err) + } + + return t.ServeConn(conn, handler) +} + +// serveConn is direct way to provide conn on which this worker will listen +// UDPReadWorkers are used to create more workers +func (t *WSTransport) ServeConn(l net.Listener, handler sip.MessageHandler) error { + if t.listener != nil { + return fmt.Errorf("TCP transport instance can only listen on one lection") + } + + t.log.Debug().Msgf("begin listening on %s %s", t.Network(), l.Addr().String()) + t.listener = l + t.handler = handler + return t.Accept() +} + +func (t *WSTransport) Accept() error { + l := t.listener + for { + conn, err := l.Accept() + if err != nil { + t.log.Error().Err(err).Msg("Fail to accept conenction") + return err + } + + _, err = ws.Upgrade(conn) + if err != nil { + return err + } + + t.initConnection(conn, conn.RemoteAddr().String()) + } +} + +func (t *WSTransport) initConnection(conn net.Conn, addr string) Connection { + // // conn.SetKeepAlive(true) + // conn.SetKeepAlivePeriod(3 * time.Second) + t.log.Debug().Str("raddr", addr).Msg("New WS connection") + c := &WSConnection{ + Conn: conn, + refcount: 3, + } + t.pool.Add(addr, c) + go t.readConnection(c, addr) + return c +} + +// This should performe better to avoid any interface allocation +func (t *WSTransport) readConnection(conn *WSConnection, raddr string) { + buf := make([]byte, UDPbufferSize) + defer conn.Close() + defer t.pool.Del(raddr) + defer t.log.Debug().Str("raddr", raddr).Msg("WS connection read stopped") + + for { + num, err := conn.Read(buf) + if err != nil { + if errors.Is(err, io.EOF) { + t.log.Debug().Msg("Got EOF") + return + } + t.log.Error().Err(err).Msg("Got TCP error") + return + } + + data := buf[:num] + + if len(bytes.Trim(data, "\x00")) == 0 { + continue + } + + t.parse(data, raddr) + } + +} + +func (t *WSTransport) parse(data []byte, src string) { + // Check is keep alive + if len(data) <= 4 { + //One or 2 CRLF + if len(bytes.Trim(data, "\r\n")) == 0 { + t.log.Debug().Msg("Keep alive CRLF received") + return + } + } + + msg, err := t.parser.Parse(data) //Very expensive operation + if err != nil { + t.log.Error().Err(err).Str("data", string(data)).Msg("failed to parse") + return + } + + msg.SetTransport(TransportWS) + msg.SetSource(src) + t.handler(msg) +} + +func (t *WSTransport) ResolveAddr(addr string) (net.Addr, error) { + return net.ResolveTCPAddr("tcp", addr) +} + +func (t *WSTransport) GetConnection(addr string) (Connection, error) { + raddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + addr = raddr.String() + + c := t.pool.Get(addr) + return c, nil +} + +func (t *WSTransport) CreateConnection(addr string) (Connection, error) { + raddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + return t.createConnection(raddr.String()) +} + +func (t *WSTransport) createConnection(addr string) (Connection, error) { + t.log.Debug().Str("raddr", addr).Msg("Dialing new connection") + + conn, _, _, err := ws.Dial(context.TODO(), "ws://"+addr) + if err != nil { + return nil, fmt.Errorf("%s dial err=%w", t, err) + } + + c := t.initConnection(conn, addr) + return c, nil +} + +type WSConnection struct { + net.Conn + + mu sync.RWMutex + refcount int +} + +func (c *WSConnection) Ref(i int) { + c.mu.Lock() + c.refcount += i + ref := c.refcount + c.mu.Unlock() + log.Debug().Str("ip", c.RemoteAddr().String()).Int("ref", ref).Msg("WS reference increment") + +} + +func (c *WSConnection) Close() error { + c.mu.Lock() + c.refcount-- + ref := c.refcount + c.mu.Unlock() + log.Debug().Str("ip", c.RemoteAddr().String()).Int("ref", c.refcount).Msg("WS reference decrement") + if ref > 0 { + return nil + } + log.Debug().Str("ip", c.RemoteAddr().String()).Int("ref", c.refcount).Msg("WS closing") + return c.Conn.Close() +} + +func (c *WSConnection) Read(b []byte) (n int, err error) { + for { + header, err := ws.ReadHeader(c.Conn) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return 0, err + } + + if header.OpCode == ws.OpClose { + return 0, io.EOF + } + + if SIPDebug { + log.Debug().Str("caller", c.LocalAddr().String()).Msgf("WS read connection header <- %s len=%d", c.Conn.RemoteAddr(), header.Length) + } + + data := make([]byte, header.Length) + + // Read until + _, err = io.ReadFull(c.Conn, data) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return 0, err + } + + if header.Masked { + ws.Cipher(data, header.Mask, 0) + } + // header.Masked = false + + n += copy(b[n:], data) + + if header.Fin { + break + } + } + + if SIPDebug { + log.Debug().Str("caller", c.LocalAddr().String()).Msgf("WS read connection <- %s: len=%d\n%s", c.Conn.RemoteAddr(), n, string(b)) + } + + return n, nil +} + +func (c *WSConnection) Write(b []byte) (n int, err error) { + fs := ws.NewFrame(ws.OpText, true, b) + err = ws.WriteFrame(c.Conn, fs) + if SIPDebug { + log.Debug().Str("caller", c.LocalAddr().String()).Msgf("WS write -> %s:\n%s", c.Conn.RemoteAddr(), string(b)) + } + return len(b), err +} + +func (c *WSConnection) WriteMsg(msg sip.Message) error { + buf := bufPool.Get().(*bytes.Buffer) + defer bufPool.Put(buf) + buf.Reset() + msg.StringWrite(buf) + data := buf.Bytes() + + n, err := c.Write(data) + if err != nil { + return fmt.Errorf("conn %s write err=%w", c.RemoteAddr().String(), err) + } + + if n == 0 { + return fmt.Errorf("wrote 0 bytes") + } + + if n != len(data) { + return fmt.Errorf("fail to write full message") + } + return nil +}