From c46cb2aeb5c9355b4aefb5ddb1e1402e8322f91b Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Thu, 10 Oct 2024 22:25:57 +0300 Subject: [PATCH] Always set contact header for outbound to IP. --- pkg/config/config.go | 16 ++++++------ pkg/sip/client.go | 14 ++++++++--- pkg/sip/config.go | 48 ++++++++++++++++++++++-------------- pkg/sip/inbound.go | 9 ++++--- pkg/sip/media_port.go | 5 ++-- pkg/sip/media_port_test.go | 13 ++++++++-- pkg/sip/outbound.go | 17 ++++++++----- pkg/sip/server.go | 11 ++++++--- pkg/sip/signaling.go | 12 ++++----- pkg/sip/types.go | 11 +++++++++ pkg/siptest/client.go | 13 +++++----- test/integration/sip_test.go | 2 +- 12 files changed, 112 insertions(+), 59 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index e029acc1..d1065f0e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -17,6 +17,7 @@ package config import ( "fmt" "net" + "net/netip" "os" "time" @@ -164,14 +165,14 @@ func (c *Config) GetLoggerFields() logrus.Fields { return fields } -func GetLocalIP() (string, error) { +func GetLocalIP() (netip.Addr, error) { ifaces, err := net.Interfaces() if err != nil { - return "", nil + return netip.Addr{}, nil } type Iface struct { Name string - Addr net.IP + Addr netip.Addr } var candidates []Iface for _, ifc := range ifaces { @@ -191,15 +192,16 @@ func GetLocalIP() (string, error) { continue } if ip4 := ipnet.IP.To4(); ip4 != nil { + ip, _ := netip.AddrFromSlice(ip4) candidates = append(candidates, Iface{ - Name: ifc.Name, Addr: ip4, + Name: ifc.Name, Addr: ip, }) - logger.Debugw("considering interface", "iface", ifc.Name, "ip", ip4) + logger.Debugw("considering interface", "iface", ifc.Name, "ip", ip) } } } if len(candidates) == 0 { - return "", fmt.Errorf("No local IP found") + return netip.Addr{}, fmt.Errorf("No local IP found") } - return candidates[0].Addr.String(), nil + return candidates[0].Addr, nil } diff --git a/pkg/sip/client.go b/pkg/sip/client.go index bc8b25c0..bad9bc61 100644 --- a/pkg/sip/client.go +++ b/pkg/sip/client.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "log/slog" + "net/netip" "sync" "github.com/emiago/sipgo" @@ -28,6 +29,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" "github.com/livekit/protocol/tracer" + "github.com/livekit/sip/pkg/config" siperrors "github.com/livekit/sip/pkg/errors" "github.com/livekit/sip/pkg/stats" @@ -39,8 +41,8 @@ type Client struct { mon *stats.Monitor sipCli *sipgo.Client - signalingIp string - signalingIpLocal string + signalingIp netip.Addr + signalingIpLocal netip.Addr closing core.Fuse cmu sync.Mutex @@ -74,7 +76,11 @@ func (c *Client) Start(agent *sipgo.UserAgent) error { return err } } else if c.conf.NAT1To1IP != "" { - c.signalingIp = c.conf.NAT1To1IP + ip, err := netip.ParseAddr(c.conf.NAT1To1IP) + if err != nil { + return err + } + c.signalingIp = ip c.signalingIpLocal = c.signalingIp } else { if c.signalingIp, err = getLocalIP(c.conf.LocalNet); err != nil { @@ -95,7 +101,7 @@ func (c *Client) Start(agent *sipgo.UserAgent) error { } c.sipCli, err = sipgo.NewClient(agent, - sipgo.WithClientHostname(c.signalingIp), + sipgo.WithClientHostname(c.signalingIp.String()), sipgo.WithClientLogger(slog.New(logger.ToSlogHandler(c.log))), ) if err != nil { diff --git a/pkg/sip/config.go b/pkg/sip/config.go index daa70afd..a96d9c7f 100644 --- a/pkg/sip/config.go +++ b/pkg/sip/config.go @@ -20,45 +20,47 @@ import ( "io" "net" "net/http" + "net/netip" ) -func getPublicIP() (string, error) { +func getPublicIP() (netip.Addr, error) { req, err := http.Get("http://ip-api.com/json/") if err != nil { - return "", err + return netip.Addr{}, err } defer req.Body.Close() body, err := io.ReadAll(req.Body) if err != nil { - return "", err + return netip.Addr{}, err } ip := struct { Query string }{} if err = json.Unmarshal(body, &ip); err != nil { - return "", err + return netip.Addr{}, err } if ip.Query == "" { - return "", fmt.Errorf("Query entry was not populated") + return netip.Addr{}, fmt.Errorf("Query entry was not populated") } - return ip.Query, nil + return netip.ParseAddr(ip.Query) } -func getLocalIP(localNet string) (string, error) { +func getLocalIP(localNet string) (netip.Addr, error) { ifaces, err := net.Interfaces() if err != nil { - return "", err + return netip.Addr{}, err } - var netw *net.IPNet + var netw *netip.Prefix if localNet != "" { - _, netw, err = net.ParseCIDR(localNet) + nw, err := netip.ParsePrefix(localNet) if err != nil { - return "", err + return netip.Addr{}, err } + netw = &nw } for _, i := range ifaces { addrs, err := i.Addrs() @@ -69,23 +71,33 @@ func getLocalIP(localNet string) (string, error) { for _, a := range addrs { switch v := a.(type) { case *net.IPAddr: - if netw != nil && !netw.Contains(v.IP) { + if v.IP.To4() == nil { continue } - if !v.IP.IsLoopback() && v.IP.To4() != nil { - return v.IP.String(), nil + ip, ok := netip.AddrFromSlice(v.IP.To4()) + if !ok || ip.IsLoopback() { + continue + } + if netw != nil && !netw.Contains(ip) { + continue } + return ip, nil case *net.IPNet: - if netw != nil && !netw.Contains(v.IP) { + if v.IP.To4() == nil { continue } - if !v.IP.IsLoopback() && v.IP.To4() != nil { - return v.IP.String(), nil + ip, ok := netip.AddrFromSlice(v.IP.To4()) + if !ok || ip.IsLoopback() { + continue + } + if netw != nil && !netw.Contains(ip) { + continue } + return ip, nil } } } - return "", fmt.Errorf("No local interface found") + return netip.Addr{}, fmt.Errorf("No local interface found") } diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go index 7089d400..ab729fde 100644 --- a/pkg/sip/inbound.go +++ b/pkg/sip/inbound.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "net/netip" "slices" "sync" "sync/atomic" @@ -362,7 +363,7 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI } acceptCall := func() bool { c.log.Infow("Accepting the call", "headers", disp.Headers) - if err = c.cc.Accept(ctx, c.s.signalingIp, c.s.conf.SIPPort, answerData, disp.Headers); err != nil { + if err = c.cc.Accept(ctx, netip.AddrPortFrom(c.s.signalingIp, uint16(c.s.conf.SIPPort)), answerData, disp.Headers); err != nil { c.log.Errorw("Cannot respond to INVITE", err) return false } @@ -965,7 +966,7 @@ func (c *sipInbound) setDestFromVia(r *sip.Response) { } } -func (c *sipInbound) Accept(ctx context.Context, contactHost string, contactPort int, sdpData []byte, headers map[string]string) error { +func (c *sipInbound) Accept(ctx context.Context, contactHost netip.AddrPort, sdpData []byte, headers map[string]string) error { ctx, span := tracer.Start(ctx, "sipInbound.Accept") defer span.End() c.mu.Lock() @@ -976,7 +977,7 @@ func (c *sipInbound) Accept(ctx context.Context, contactHost string, contactPort r := sip.NewResponseFromRequest(c.invite, 200, "OK", sdpData) // This will effectively redirect future SIP requests to this server instance (if host address is not LB). - r.AppendHeader(&sip.ContactHeader{Address: sip.Uri{Host: contactHost, Port: contactPort}}) + r.AppendHeader(&sip.ContactHeader{Address: sip.Uri{Host: contactHost.Addr().String(), Port: int(contactHost.Port())}}) c.setDestFromVia(r) @@ -1081,7 +1082,7 @@ func (c *sipInbound) newReferReq(transferTo string) (*sip.Request, error) { } // This will effectively redirect future SIP requests to this server instance (if host address is not LB). - contactHeader := &sip.ContactHeader{Address: sip.Uri{Host: c.s.signalingIp, Port: c.s.conf.SIPPort}} + contactHeader := &sip.ContactHeader{Address: sip.Uri{Host: c.s.signalingIp.String(), Port: c.s.conf.SIPPort}} req := NewReferRequest(c.invite, c.inviteOk, contactHeader, transferTo) c.setCSeq(req) diff --git a/pkg/sip/media_port.go b/pkg/sip/media_port.go index 1676cf76..1b5c728a 100644 --- a/pkg/sip/media_port.go +++ b/pkg/sip/media_port.go @@ -16,6 +16,7 @@ package sip import ( "context" + "net/netip" "sync" "sync/atomic" "time" @@ -33,7 +34,7 @@ import ( ) type MediaConfig struct { - IP string + IP netip.Addr Ports rtcconfig.PortRange MediaTimeoutInitial time.Duration MediaTimeout time.Duration @@ -71,7 +72,7 @@ func NewMediaPortWith(log logger.Logger, mon *stats.CallMonitor, conn rtp.UDPCon type MediaPort struct { log logger.Logger mon *stats.CallMonitor - externalIP string + externalIP netip.Addr conn *rtp.Conn mediaTimeout <-chan struct{} dtmfAudioEnabled bool diff --git a/pkg/sip/media_port_test.go b/pkg/sip/media_port_test.go index 98ea6c93..44079535 100644 --- a/pkg/sip/media_port_test.go +++ b/pkg/sip/media_port_test.go @@ -19,6 +19,7 @@ import ( "io" "math" "net" + "net/netip" "slices" "strconv" "strings" @@ -111,6 +112,14 @@ func PrintAudioInWriter(p *MediaPort) string { return p.audioInHandler.(fmt.Stringer).String() } +func newIP(v string) netip.Addr { + ip, err := netip.ParseAddr(v) + if err != nil { + panic(err) + } + return ip +} + func TestMediaPort(t *testing.T) { codecs := media.Codecs() disableAll := func() { @@ -151,14 +160,14 @@ func TestMediaPort(t *testing.T) { log := logger.GetLogger() m1, err := NewMediaPortWith(log.WithName("one"), nil, c1, &MediaConfig{ - IP: "1.1.1.1", + IP: newIP("1.1.1.1"), Ports: rtcconfig.PortRange{Start: 10000}, }, rate) require.NoError(t, err) defer m1.Close() m2, err := NewMediaPortWith(log.WithName("two"), nil, c2, &MediaConfig{ - IP: "2.2.2.2", + IP: newIP("2.2.2.2"), Ports: rtcconfig.PortRange{Start: 20000}, }, rate) require.NoError(t, err) diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go index 046bcafc..a90ace29 100644 --- a/pkg/sip/outbound.go +++ b/pkg/sip/outbound.go @@ -74,7 +74,7 @@ type outboundCall struct { func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Logger, id LocalTag, room RoomConfig, sipConf sipOutboundConfig) (*outboundCall, error) { if sipConf.host == "" { - sipConf.host = c.signalingIp + sipConf.host = c.signalingIp.String() } call := &outboundCall{ c: c, @@ -82,7 +82,7 @@ func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Lo cc: c.newOutbound(id, URI{ User: sipConf.from, Host: sipConf.host, - Addr: netip.AddrPortFrom(netip.Addr{}, uint16(conf.SIPPort)), + Addr: netip.AddrPortFrom(c.signalingIp, uint16(conf.SIPPort)), }), sipConf: sipConf, } @@ -438,19 +438,24 @@ func (c *Client) newOutbound(id LocalTag, from URI) *sipOutbound { Address: *from.GetURI(), Params: sip.NewParams(), } + contactHeader := &sip.ContactHeader{ + Address: *from.GetContactURI(), + } fromHeader.Params.Add("tag", string(id)) return &sipOutbound{ c: c, id: id, from: fromHeader, + contact: contactHeader, referDone: make(chan error), // Do not buffer the channel to avoid reading a result for an old request } } type sipOutbound struct { - c *Client - id LocalTag - from *sip.FromHeader + c *Client + id LocalTag + from *sip.FromHeader + contact *sip.ContactHeader mu sync.RWMutex tag RemoteTag @@ -631,7 +636,7 @@ func (c *sipOutbound) attemptInvite(ctx context.Context, dest string, to *sip.To req.SetBody(offer) req.AppendHeader(to) req.AppendHeader(c.from) - req.AppendHeader(&sip.ContactHeader{Address: c.from.Address}) + req.AppendHeader(c.contact) req.AppendHeader(sip.NewHeader("Content-Type", "application/sdp")) req.AppendHeader(sip.NewHeader("Allow", "INVITE, ACK, CANCEL, BYE, NOTIFY, REFER, MESSAGE, OPTIONS, INFO, SUBSCRIBE")) diff --git a/pkg/sip/server.go b/pkg/sip/server.go index cbf89c81..ffcc976d 100644 --- a/pkg/sip/server.go +++ b/pkg/sip/server.go @@ -31,6 +31,7 @@ import ( "github.com/livekit/protocol/logger" "github.com/livekit/protocol/rpc" + "github.com/livekit/sip/pkg/media" "github.com/livekit/sip/pkg/config" @@ -110,8 +111,8 @@ type Server struct { sipConnUDP *net.UDPConn sipConnTCP *net.TCPListener sipUnhandled RequestHandler - signalingIp string - signalingIpLocal string + signalingIp netip.Addr + signalingIpLocal netip.Addr inProgressInvites []*inProgressInvite @@ -208,7 +209,11 @@ func (s *Server) Start(agent *sipgo.UserAgent, unhandled RequestHandler) error { return err } } else if s.conf.NAT1To1IP != "" { - s.signalingIp = s.conf.NAT1To1IP + ip, err := netip.ParseAddr(s.conf.NAT1To1IP) + if err != nil { + return err + } + s.signalingIp = ip s.signalingIpLocal = s.signalingIp } else { if s.signalingIp, err = getLocalIP(s.conf.LocalNet); err != nil { diff --git a/pkg/sip/signaling.go b/pkg/sip/signaling.go index 21dbc327..ef8109d8 100644 --- a/pkg/sip/signaling.go +++ b/pkg/sip/signaling.go @@ -147,7 +147,7 @@ func sdpAnswerMediaDesc(rtpListenerPort int, res *MediaConf) []*sdp.MediaDescrip } } -func sdpGenerateOffer(publicIp string, rtpListenerPort int) ([]byte, error) { +func sdpGenerateOffer(publicIp netip.Addr, rtpListenerPort int) ([]byte, error) { sessId := rand.Uint64() // TODO: do we need to track these? mediaDesc := sdpMediaOffer(rtpListenerPort) @@ -159,13 +159,13 @@ func sdpGenerateOffer(publicIp string, rtpListenerPort int) ([]byte, error) { SessionVersion: sessId, NetworkType: "IN", AddressType: "IP4", - UnicastAddress: publicIp, + UnicastAddress: publicIp.String(), }, SessionName: "LiveKit", ConnectionInformation: &sdp.ConnectionInformation{ NetworkType: "IN", AddressType: "IP4", - Address: &sdp.Address{Address: publicIp}, + Address: &sdp.Address{Address: publicIp.String()}, }, TimeDescriptions: []sdp.TimeDescription{ { @@ -182,7 +182,7 @@ func sdpGenerateOffer(publicIp string, rtpListenerPort int) ([]byte, error) { return data, err } -func sdpGenerateAnswer(offer *sdp.SessionDescription, publicIp string, rtpListenerPort int, res *MediaConf) ([]byte, error) { +func sdpGenerateAnswer(offer *sdp.SessionDescription, publicIp netip.Addr, rtpListenerPort int, res *MediaConf) ([]byte, error) { answer := sdp.SessionDescription{ Version: 0, Origin: sdp.Origin{ @@ -191,13 +191,13 @@ func sdpGenerateAnswer(offer *sdp.SessionDescription, publicIp string, rtpListen SessionVersion: offer.Origin.SessionID + 2, NetworkType: "IN", AddressType: "IP4", - UnicastAddress: publicIp, + UnicastAddress: publicIp.String(), }, SessionName: "LiveKit", ConnectionInformation: &sdp.ConnectionInformation{ NetworkType: "IN", AddressType: "IP4", - Address: &sdp.Address{Address: publicIp}, + Address: &sdp.Address{Address: publicIp.String()}, }, TimeDescriptions: []sdp.TimeDescription{ { diff --git a/pkg/sip/types.go b/pkg/sip/types.go index a8bb0edf..bc4a8f5a 100644 --- a/pkg/sip/types.go +++ b/pkg/sip/types.go @@ -100,6 +100,17 @@ func (u URI) GetURI() *sip.Uri { return su } +func (u URI) GetContactURI() *sip.Uri { + su := &sip.Uri{ + User: u.User, + Host: u.Addr.Addr().String(), + } + if port := u.Addr.Port(); port != 0 { + su.Port = int(port) + } + return su +} + type LocalTag string type RemoteTag string diff --git a/pkg/siptest/client.go b/pkg/siptest/client.go index a0b0a2cc..f14eb92e 100644 --- a/pkg/siptest/client.go +++ b/pkg/siptest/client.go @@ -24,6 +24,7 @@ import ( "math" "math/rand" "net" + "net/netip" "os" "slices" "strconv" @@ -49,7 +50,7 @@ import ( ) type ClientConfig struct { - IP string + IP netip.Addr Number string AuthUser string AuthPass string @@ -73,7 +74,7 @@ func NewClient(id string, conf ClientConfig) (*Client, error) { if id != "" { conf.Log = conf.Log.With("id", id) } - if conf.IP == "" { + if !conf.IP.IsValid() { localIP, err := config.GetLocalIP() if err != nil { return nil, err @@ -123,7 +124,7 @@ func NewClient(id string, conf ClientConfig) (*Client, error) { return nil, err } - cli.sipClient, err = sipgo.NewClient(ua, sipgo.WithClientHostname(conf.IP)) + cli.sipClient, err = sipgo.NewClient(ua, sipgo.WithClientHostname(conf.IP.String())) if err != nil { cli.Close() return nil, err @@ -183,7 +184,7 @@ type Client struct { } func (c *Client) LocalIP() string { - return c.conf.IP + return c.conf.IP.String() } func (c *Client) RemoteHeaders() []sip.Header { @@ -501,13 +502,13 @@ func (c *Client) createOffer() ([]byte, error) { SessionVersion: sessionId, NetworkType: "IN", AddressType: "IP4", - UnicastAddress: c.conf.IP, + UnicastAddress: c.conf.IP.String(), }, SessionName: "LiveKit", ConnectionInformation: &sdp.ConnectionInformation{ NetworkType: "IN", AddressType: "IP4", - Address: &sdp.Address{Address: c.conf.IP}, + Address: &sdp.Address{Address: c.conf.IP.String()}, }, TimeDescriptions: []sdp.TimeDescription{ { diff --git a/test/integration/sip_test.go b/test/integration/sip_test.go index 8ddacd19..93c7ad04 100644 --- a/test/integration/sip_test.go +++ b/test/integration/sip_test.go @@ -66,7 +66,7 @@ func runSIPServer(t testing.TB, lk *LiveKit) *SIPServer { Redis: lk.Redis, SIPPort: sipPort, SIPPortListen: sipPort, - ListenIP: local, + ListenIP: local.String(), RTPPort: rtcconfig.PortRange{Start: 20000, End: 20010}, UseExternalIP: false, MaxCpuUtilization: 0.9,