From 8da20f7d7215d2e0aae67b9a055f0e660a16c4d7 Mon Sep 17 00:00:00 2001 From: Emir Aganovic Date: Mon, 14 Aug 2023 14:02:04 +0200 Subject: [PATCH] fix: Setting protocol headers in websocket and add register support for wss #35 --- client.go | 15 +++++++++ example/register/client/main.go | 19 +++++++---- example/register/go.mod | 4 +-- example/register/go.sum | 9 +++--- example/register/server/main.go | 30 +++++++++++++++--- server.go | 6 ++-- server_integration_test.go | 2 +- sip/utils.go | 12 +++++++ testdata/generate_certs_rsa.sh | 56 +++++++++++++++++++++++++++++++++ transport/ws.go | 44 +++++++++++++++++++++++--- transport/wss.go | 11 ++++--- ua.go | 1 + 12 files changed, 180 insertions(+), 29 deletions(-) create mode 100755 testdata/generate_certs_rsa.sh diff --git a/client.go b/client.go index 8a92e4e..0a5fefc 100644 --- a/client.go +++ b/client.go @@ -49,6 +49,21 @@ func WithClientPort(port int) ClientOption { } } +// WithClientAddr is merge of WithClientHostname and WithClientPort +// addr is format : +func WithClientAddr(addr string) ClientOption { + return func(s *Client) error { + host, port, err := sip.SplitHostPort(addr) + if err != nil { + return err + } + + WithClientHostname(host) + WithClientPort(port) + return nil + } +} + // NewClient creates client handle for user agent func NewClient(ua *UserAgent, options ...ClientOption) (*Client, error) { c := &Client{ diff --git a/example/register/client/main.go b/example/register/client/main.go index 7dc3282..8c84467 100644 --- a/example/register/client/main.go +++ b/example/register/client/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "os" + "strings" "time" "github.com/emiago/sipgo" @@ -20,24 +21,27 @@ import ( func main() { extIP := flag.String("ip", "127.0.0.50:5060", "My exernal ip") dst := flag.String("srv", "127.0.0.10:5060", "Destination") + tran := flag.String("t", "udp", "Transport") username := flag.String("u", "alice", "SIP Username") password := flag.String("p", "alice", "Password") - sipdebug := flag.Bool("sipdebug", false, "Turn on sipdebug") flag.Parse() // Make SIP Debugging available - transport.SIPDebug = *sipdebug + transport.SIPDebug = os.Getenv("SIP_DEBUG") != "" zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMicro log.Logger = zerolog.New(zerolog.ConsoleWriter{ Out: os.Stdout, TimeFormat: time.StampMicro, - }).With().Timestamp().Logger().Level(zerolog.DebugLevel) + }).With().Timestamp().Logger().Level(zerolog.InfoLevel) + + if lvl, err := zerolog.ParseLevel(os.Getenv("LOG_LEVEL")); err == nil && lvl != zerolog.NoLevel { + log.Logger = log.Logger.Level(lvl) + } // Setup UAC ua, err := sipgo.NewUA( sipgo.WithUserAgent(*username), - sipgo.WithUserAgentIP(*extIP), ) if err != nil { log.Fatal().Err(err).Msg("Fail to setup user agent") @@ -48,16 +52,17 @@ func main() { log.Fatal().Err(err).Msg("Fail to setup server handle") } - client, err := sipgo.NewClient(ua) + client, err := sipgo.NewClient(ua, sipgo.WithClientAddr(*extIP)) if err != nil { log.Fatal().Err(err).Msg("Fail to setup client handle") } ctx := context.TODO() - go srv.ListenAndServe(ctx, "udp", *extIP) + go srv.ListenAndServe(ctx, *tran, *extIP) // Wait that ouir server loads time.Sleep(1 * time.Second) + log.Info().Str("addr", *extIP).Msg("Server listening on") // Create basic REGISTER request structure recipient := &sip.Uri{} @@ -66,9 +71,11 @@ func main() { req.AppendHeader( sip.NewHeader("Contact", fmt.Sprintf("", *username, *extIP)), ) + req.SetTransport(strings.ToUpper(*tran)) // Send request and parse response // req.SetDestination(*dst) + log.Info().Msg(req.StartLine()) tx, err := client.TransactionRequest(req.Clone()) if err != nil { log.Fatal().Err(err).Msg("Fail to create transaction") diff --git a/example/register/go.mod b/example/register/go.mod index 84a4a91..36350b4 100644 --- a/example/register/go.mod +++ b/example/register/go.mod @@ -11,12 +11,12 @@ require ( require ( github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect - github.com/gobwas/ws v1.1.0 // indirect + github.com/gobwas/ws v1.2.1 // indirect github.com/google/uuid v1.3.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.16 // indirect github.com/satori/go.uuid v1.2.1-0.20181028125025-b2ce2384e17b // indirect - golang.org/x/sys v0.3.0 // indirect + golang.org/x/sys v0.6.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/example/register/go.sum b/example/register/go.sum index 26adcaf..6763700 100644 --- a/example/register/go.sum +++ b/example/register/go.sum @@ -4,8 +4,8 @@ github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.1.0 h1:7RFti/xnNkMJnrK7D1yQ/iCIB5OrrY/54/H930kIbHA= -github.com/gobwas/ws v1.1.0/go.mod h1:nzvNcVha5eUziGrbxFCo6qFIojQHjJV5cLYIbezhfL0= +github.com/gobwas/ws v1.2.1 h1:F2aeBZrm2NDsc7vbovKrWSogd4wvfAxg0FQ89/iqOTk= +github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -39,12 +39,11 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201207223542-d4d67f95c62d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190624222133-a101b041ded4/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/example/register/server/main.go b/example/register/server/main.go index d4d6426..7f6497e 100644 --- a/example/register/server/main.go +++ b/example/register/server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "flag" "fmt" "os" @@ -20,17 +21,23 @@ import ( func main() { extIP := flag.String("ip", "127.0.0.1:5060", "My exernal ip") creds := flag.String("u", "alice:alice", "Coma seperated username:password list") - sipdebug := flag.Bool("sipdebug", false, "Turn on sipdebug") + tran := flag.String("t", "udp", "Transport") + tlskey := flag.String("tlskey", "", "TLS key path") + tlscrt := flag.String("tlscrt", "", "TLS crt path") flag.Parse() // Make SIP Debugging available - transport.SIPDebug = *sipdebug + transport.SIPDebug = os.Getenv("SIP_DEBUG") != "" zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMicro log.Logger = zerolog.New(zerolog.ConsoleWriter{ Out: os.Stdout, TimeFormat: time.StampMicro, - }).With().Timestamp().Logger().Level(zerolog.DebugLevel) + }).With().Timestamp().Logger().Level(zerolog.InfoLevel) + + if lvl, err := zerolog.ParseLevel(os.Getenv("LOG_LEVEL")); err == nil && lvl != zerolog.NoLevel { + log.Logger = log.Logger.Level(lvl) + } registry := make(map[string]string) for _, c := range strings.Split(*creds, ",") { @@ -110,5 +117,20 @@ func main() { tx.Respond(sip.NewResponseFromRequest(req, 200, "OK", nil)) }) - srv.ListenAndServe(ctx, "udp", *extIP) + log.Info().Str("addr", *extIP).Msg("Listening on") + + switch *tran { + case "tls", "wss": + cert, err := tls.LoadX509KeyPair(*tlscrt, *tlskey) + if err != nil { + + log.Fatal().Err(err).Msg("Fail to load x509 key and crt") + } + if err := srv.ListenAndServeTLS(ctx, *tran, *extIP, &tls.Config{Certificates: []tls.Certificate{cert}}); err != nil { + log.Info().Err(err).Msg("Listening stop") + } + return + } + + srv.ListenAndServe(ctx, *tran, *extIP) } diff --git a/server.go b/server.go index 5395a50..b2476f1 100644 --- a/server.go +++ b/server.go @@ -208,17 +208,17 @@ func (srv *Server) ServeTCP(l net.Listener) error { } // ServeTLS starts serving request on TLS type listener. -func (srv *Server) ServeTLS(l net.Listener, conf *tls.Config) error { +func (srv *Server) ServeTLS(l net.Listener) error { return srv.tp.ServeTLS(l) } // ServeWS starts serving request on WS type listener. -func (srv *Server) ServeWS(l net.Listener, conf *tls.Config) error { +func (srv *Server) ServeWS(l net.Listener) error { return srv.tp.ServeWS(l) } // ServeWS starts serving request on WS type listener. -func (srv *Server) ServeWSS(l net.Listener, conf *tls.Config) error { +func (srv *Server) ServeWSS(l net.Listener) error { return srv.tp.ServeWSS(l) } diff --git a/server_integration_test.go b/server_integration_test.go index 80d7c52..5feed90 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -20,7 +20,7 @@ import ( // This will generate TLS certificates needed for test below // openssl is required -//go:generate bash -c "cd testdata && ./generate_certs.sh" +//go:generate bash -c "cd testdata && ./generate_certs_rsa.sh" var ( //go:embed testdata/certs/server.crt diff --git a/sip/utils.go b/sip/utils.go index a2fcc54..2e3c954 100644 --- a/sip/utils.go +++ b/sip/utils.go @@ -5,6 +5,7 @@ import ( "errors" "math/rand" "net" + "strconv" "strings" "time" ) @@ -265,3 +266,14 @@ func MessageShortString(msg Message) string { } return "Unknown message type" } + +func SplitHostPort(addr string) (host string, port int, err error) { + var p string + host, p, err = net.SplitHostPort(addr) + if err != nil { + return + } + + port, err = strconv.Atoi(p) + return +} diff --git a/testdata/generate_certs_rsa.sh b/testdata/generate_certs_rsa.sh new file mode 100755 index 0000000..2739979 --- /dev/null +++ b/testdata/generate_certs_rsa.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +mkdir -p certs && cd certs + +set -x +# Create root CA +openssl genrsa -out rootca-key.pem 2048 + +openssl req -new -x509 -nodes -days 3650 \ + -subj "/C=US/ST=California/CN=localhost" \ + -key rootca-key.pem \ + -out rootca-cert.pem + + +# Server +openssl req -newkey rsa:2048 -nodes \ + -subj "/C=US/ST=California/CN=localhost" \ + -keyout server.key \ + -out server.csr + + +openssl x509 -req -days 3650 -set_serial 01 \ + -in server.csr \ + -out server.crt \ + -CA rootca-cert.pem \ + -CAkey rootca-key.pem \ + -extensions SAN \ + -extfile <(printf "\n[SAN]\nsubjectAltName=IP:127.1.1.100\nextendedKeyUsage=serverAuth") + + +echo "Server cert and key created" +echo "===========================" +openssl x509 -noout -text -in server.crt +echo "===========================" + +# Client +openssl req -newkey rsa:2048 -nodes \ + -subj "/C=US/ST=California/CN=localhost" \ + -keyout client.key \ + -out client.csr + + +openssl x509 -req -days 3650 -set_serial 01 \ + -in client.csr \ + -out client.crt \ + -CA rootca-cert.pem \ + -CAkey rootca-key.pem \ + -extensions SAN \ + -extfile <(printf "\n[SAN]\nsubjectAltName=IP:127.1.1.100\nextendedKeyUsage=clientAuth") + + +echo "Client cert and key created" +echo "===========================" +openssl x509 -noout -text -in client.crt +echo "===========================" +set +x \ No newline at end of file diff --git a/transport/ws.go b/transport/ws.go index 1be6979..cb8f4cf 100644 --- a/transport/ws.go +++ b/transport/ws.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "net/http" "sync" "time" @@ -18,13 +19,20 @@ import ( "github.com/rs/zerolog/log" ) +var ( + // WebSocketProtocols is used in setting websocket header + // By default clients must accept protocol sip + WebSocketProtocols = []string{"sip"} +) + // WS transport implementation type WSTransport struct { parser sip.Parser log zerolog.Logger transport string - pool ConnectionPool + pool ConnectionPool + dialer ws.Dialer } func NewWSTransport(par sip.Parser) *WSTransport { @@ -32,7 +40,10 @@ func NewWSTransport(par sip.Parser) *WSTransport { parser: par, pool: NewConnectionPool(), transport: TransportWS, + dialer: ws.DefaultDialer, } + + p.dialer.Protocols = WebSocketProtocols p.log = log.Logger.With().Str("caller", "transport").Logger() return p } @@ -53,6 +64,27 @@ func (t *WSTransport) Close() error { // Serve is direct way to provide conn on which this worker will listen func (t *WSTransport) Serve(l net.Listener, handler sip.MessageHandler) error { t.log.Debug().Msgf("begin listening on %s %s", t.Network(), l.Addr().String()) + + // Prepare handshake header writer from http.Header mapping. + // Some phones want to return this + // TODO make this configurable + header := ws.HandshakeHeaderHTTP(http.Header{ + "Sec-WebSocket-Protocol": WebSocketProtocols, + }) + + u := ws.Upgrader{ + OnBeforeUpgrade: func() (ws.HandshakeHeader, error) { + return header, nil + }, + } + + if SIPDebug { + u.OnHeader = func(key, value []byte) error { + log.Debug().Str(string(key), string(value)).Msg("non-websocket header:") + return nil + } + } + for { conn, err := l.Accept() if err != nil { @@ -64,9 +96,10 @@ func (t *WSTransport) Serve(l net.Listener, handler sip.MessageHandler) error { t.log.Debug().Str("addr", raddr).Msg("New connection accept") - _, err = ws.Upgrade(conn) + _, err = u.Upgrade(conn) if err != nil { - return err + t.log.Error().Err(err).Msg("Fail to upgrade") + continue } t.initConnection(conn, raddr, false, handler) @@ -184,7 +217,10 @@ func (t *WSTransport) CreateConnection(addr string, handler sip.MessageHandler) func (t *WSTransport) createConnection(addr string, handler sip.MessageHandler) (Connection, error) { t.log.Debug().Str("raddr", addr).Msg("Dialing new connection") - conn, _, _, err := ws.Dial(context.TODO(), "ws://"+addr) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, _, _, err := t.dialer.Dial(ctx, "ws://"+addr) if err != nil { return nil, fmt.Errorf("%s dial err=%w", t, err) } diff --git a/transport/wss.go b/transport/wss.go index e455778..45613b8 100644 --- a/transport/wss.go +++ b/transport/wss.go @@ -5,9 +5,9 @@ import ( "crypto/tls" "fmt" "net" + "time" "github.com/emiago/sipgo/sip" - "github.com/gobwas/ws" "github.com/rs/zerolog/log" ) @@ -23,12 +23,12 @@ type WSSTransport struct { func NewWSSTransport(par sip.Parser, dialTLSConf *tls.Config) *WSSTransport { tcptrans := NewWSTransport(par) tcptrans.transport = TransportWSS + // Set our TLS config p := &WSSTransport{ WSTransport: tcptrans, } - // TODO should have single or multiple dialers - ws.DefaultDialer.TLSConfig = dialTLSConf + p.dialer.TLSConfig = dialTLSConf // p.tlsConf = dialTLSConf p.log = log.Logger.With().Str("caller", "transport").Logger() @@ -52,7 +52,10 @@ func (t *WSSTransport) CreateConnection(addr string, handler sip.MessageHandler) func (t *WSSTransport) createConnection(addr string, handler sip.MessageHandler) (Connection, error) { t.log.Debug().Str("raddr", addr).Msg("Dialing new connection") - conn, _, _, err := ws.Dial(context.TODO(), "wss://"+addr) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, _, _, err := t.dialer.Dial(ctx, "wss://"+addr) if err != nil { return nil, fmt.Errorf("%s dial err=%w", t, err) } diff --git a/ua.go b/ua.go index 1ddfece..b2e28fd 100644 --- a/ua.go +++ b/ua.go @@ -32,6 +32,7 @@ func WithUserAgent(ua string) UserAgentOption { // WithUserAgentIP sets local IP that will be used in building request // If not used IP will be resolved +// Deprecated: Use on client WithClientHostname WithClientPort func WithUserAgentIP(ip net.IP) UserAgentOption { return func(s *UserAgent) error { return s.setIP(ip)