From 9a076d63800013d7cede0870315178131c76839f Mon Sep 17 00:00:00 2001 From: Emir Aganovic Date: Sat, 16 Mar 2024 19:01:32 +0100 Subject: [PATCH] fix: dialog handle Record Route and Route headers --- client.go | 1 + dialog_client.go | 10 +++ dialog_server.go | 5 ++ sip/parse_uri_test.go | 137 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+) create mode 100644 sip/parse_uri_test.go diff --git a/client.go b/client.go index 4237ee5..65d6b3e 100644 --- a/client.go +++ b/client.go @@ -322,6 +322,7 @@ func ClientRequestAddRecordRoute(c *Client, r *sip.Request) error { "transport": sip.NetworkToLower(r.Transport()), "lr": "", }, + Headers: sip.NewParams(), }, } diff --git a/dialog_client.go b/dialog_client.go index b351061..df71699 100644 --- a/dialog_client.go +++ b/dialog_client.go @@ -243,6 +243,11 @@ func (s *DialogClientSession) Ack(ctx context.Context) error { } func (s *DialogClientSession) WriteAck(ctx context.Context, ack *sip.Request) error { + // Check Record-Route Header + if rr := s.InviteResponse.RecordRoute(); rr != nil { + ack.SetDestination(rr.Address.HostPort()) + } + if err := s.dc.c.WriteRequest(ack); err != nil { // Make sure we close our error // s.Close() @@ -273,6 +278,11 @@ func (s *DialogClientSession) WriteBye(ctx context.Context, bye *sip.Request) er return fmt.Errorf("Dialog not confirmed. ACK not send?") } + // Check Record-Route Header + if rr := s.InviteResponse.RecordRoute(); rr != nil { + bye.SetDestination(rr.Address.HostPort()) + } + tx, err := dc.c.TransactionRequest(ctx, bye) if err != nil { return err diff --git a/dialog_server.go b/dialog_server.go index 934175a..7a4f84f 100644 --- a/dialog_server.go +++ b/dialog_server.go @@ -240,6 +240,11 @@ func (s *DialogServerSession) Bye(ctx context.Context) error { return fmt.Errorf("non matching ID %q %q", s.ID, byeID) } + // Check Route Header + if rr := bye.Route(); rr != nil { + bye.SetDestination(rr.Address.HostPort()) + } + tx, err := cli.TransactionRequest(ctx, bye) if err != nil { return err diff --git a/sip/parse_uri_test.go b/sip/parse_uri_test.go new file mode 100644 index 0000000..44a105c --- /dev/null +++ b/sip/parse_uri_test.go @@ -0,0 +1,137 @@ +package sip + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseUri(t *testing.T) { + // This are all good accepted URIs test. + + /* + https://datatracker.ietf.org/doc/html/rfc3261#section-19.1.3 + sip:alice@atlanta.com + sip:alice:secretword@atlanta.com;transport=tcp + sips:alice@atlanta.com?subject=project%20x&priority=urgent + sip:+1-212-555-1212:1234@gateway.com;user=phone + sips:1212@gateway.com + sip:alice@192.0.2.4 + sip:atlanta.com;method=REGISTER?to=alice%40atlanta.com + sip:alice;day=tuesday@atlanta.com + */ + + var uri Uri + var err error + var str string + + t.Run("basic", func(t *testing.T) { + uri = Uri{} + str = "sip:alice@localhost:5060" + err = ParseUri(str, &uri) + require.NoError(t, err) + assert.Equal(t, "alice", uri.User) + assert.Equal(t, "localhost", uri.Host) + assert.Equal(t, 5060, uri.Port) + assert.Equal(t, "localhost:5060", uri.HostPort()) + assert.Equal(t, "alice@localhost:5060", uri.Endpoint()) + }) + + t.Run("sip case insensitive", func(t *testing.T) { + testCases := []string{ + "sip:alice@atlanta.com", + "SIP:alice@atlanta.com", + "sIp:alice@atlanta.com", + } + for _, testCase := range testCases { + err = ParseUri(testCase, &uri) + require.NoError(t, err) + assert.Equal(t, "alice", uri.User) + assert.Equal(t, "atlanta.com", uri.Host) + assert.False(t, uri.Encrypted) + } + + testCases = []string{ + "sips:alice@atlanta.com", + "SIPS:alice@atlanta.com", + "sIpS:alice@atlanta.com", + } + for _, testCase := range testCases { + err = ParseUri(testCase, &uri) + require.NoError(t, err) + assert.Equal(t, "alice", uri.User) + assert.Equal(t, "atlanta.com", uri.Host) + assert.True(t, uri.Encrypted) + } + + }) + + t.Run("no sip scheme", func(t *testing.T) { + // No scheme we currently allow + uri = Uri{} + str = "alice@localhost:5060" + err = ParseUri(str, &uri) + require.NoError(t, err) + assert.Equal(t, "sip:alice@localhost:5060", uri.String()) + }) + + t.Run("uri params parsed", func(t *testing.T) { + uri = Uri{} + str = "sips:alice@atlanta.com?subject=project%20x&priority=urgent" + err = ParseUri(str, &uri) + require.NoError(t, err) + + assert.Equal(t, "alice", uri.User) + assert.Equal(t, "atlanta.com", uri.Host) + subject, _ := uri.Headers.Get("subject") + priority, _ := uri.Headers.Get("priority") + assert.Equal(t, "project%20x", subject) + assert.Equal(t, "urgent", priority) + }) + + t.Run("header params parsed", func(t *testing.T) { + uri = Uri{} + str = "sip:bob:secret@atlanta.com:9999;rport;transport=tcp;method=REGISTER?to=sip:bob%40biloxi.com" + err = ParseUri(str, &uri) + require.NoError(t, err) + + assert.Equal(t, "bob", uri.User) + assert.Equal(t, "secret", uri.Password) + assert.Equal(t, "atlanta.com", uri.Host) + assert.Equal(t, 9999, uri.Port) + + assert.Equal(t, 3, uri.UriParams.Length()) + transport, _ := uri.UriParams.Get("transport") + method, _ := uri.UriParams.Get("method") + assert.Equal(t, "tcp", transport) + assert.Equal(t, "REGISTER", method) + + assert.Equal(t, 1, uri.Headers.Length()) + to, _ := uri.Headers.Get("to") + assert.Equal(t, "sip:bob%40biloxi.com", to) + + }) + + t.Run("params no value", func(t *testing.T) { + uri = Uri{} + str = "127.0.0.2:5060;rport;branch=z9hG4bKPj6c65c5d9-b6d0-4a30-9383-1f9b42f97de9" + err = ParseUri(str, &uri) + require.NoError(t, err) + + rport, _ := uri.UriParams.Get("rport") + branch, _ := uri.UriParams.Get("branch") + assert.Equal(t, "", rport) + assert.Equal(t, "z9hG4bKPj6c65c5d9-b6d0-4a30-9383-1f9b42f97de9", branch) + + }) +} + +func TestParseUriBad(t *testing.T) { + t.Run("double ports", func(t *testing.T) { + str := "sip:127.0.0.1:5060:5060;lr;transport=udp" + uri := Uri{} + err := ParseUri(str, &uri) + require.Error(t, err) + }) +}