diff --git a/pkg/config/config.go b/pkg/config/config.go index d1065f0..8830752 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -35,13 +35,25 @@ import ( ) const ( - DefaultSIPPort int = 5060 + DefaultSIPPort int = 5060 + DefaultSIPPortTLS int = 5061 ) var ( DefaultRTPPortRange = rtcconfig.PortRange{Start: 10000, End: 20000} ) +type TLSCert struct { + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` +} + +type TLSConfig struct { + Port int `yaml:"port"` // announced SIP signaling port + ListenPort int `yaml:"port_listen"` // SIP signaling port to listen on + Certs []TLSCert `yaml:"certs"` +} + type Config struct { Redis *redis.RedisConfig `yaml:"redis"` // required ApiKey string `yaml:"api_key"` // required (env LIVEKIT_API_KEY) @@ -53,6 +65,8 @@ type Config struct { PProfPort int `yaml:"pprof_port"` SIPPort int `yaml:"sip_port"` // announced SIP signaling port SIPPortListen int `yaml:"sip_port_listen"` // SIP signaling port to listen on + SIPHostname string `yaml:"sip_hostname"` + TLS *TLSConfig `yaml:"tls"` RTPPort rtcconfig.PortRange `yaml:"rtp_port"` Logging logger.Config `yaml:"logging"` ClusterID string `yaml:"cluster_id"` // cluster this instance belongs to @@ -109,6 +123,14 @@ func (c *Config) Init() error { if c.SIPPortListen == 0 { c.SIPPortListen = c.SIPPort } + if tc := c.TLS; tc != nil { + if tc.Port == 0 { + tc.Port = DefaultSIPPortTLS + } + if tc.ListenPort == 0 { + tc.ListenPort = tc.Port + } + } if c.RTPPort.Start == 0 { c.RTPPort.Start = DefaultRTPPortRange.Start } diff --git a/pkg/sip/client.go b/pkg/sip/client.go index 1a3d9e1..c223e12 100644 --- a/pkg/sip/client.go +++ b/pkg/sip/client.go @@ -131,6 +131,10 @@ func (c *Client) SetHandler(handler Handler) { c.handler = handler } +func (c *Client) ContactURI(tr Transport) URI { + return getContactURI(c.conf, c.signalingIp, tr) +} + func (c *Client) CreateSIPParticipant(ctx context.Context, req *rpc.InternalCreateSIPParticipantRequest) (*rpc.InternalCreateSIPParticipantResponse, error) { ctx, span := tracer.Start(ctx, "Client.CreateSIPParticipant") defer span.End() diff --git a/pkg/sip/inbound.go b/pkg/sip/inbound.go index 4a667af..7001b45 100644 --- a/pkg/sip/inbound.go +++ b/pkg/sip/inbound.go @@ -134,7 +134,8 @@ func (s *Server) onInvite(req *sip.Request, tx sip.ServerTransaction) { "toIP", req.Destination(), ) - cc := s.newInbound(LocalTag(callID), req, tx) + tr := transportFromReq(req) + cc := s.newInbound(LocalTag(callID), s.ContactURI(tr), req, tx) log = LoggerWithParams(log, cc) log = LoggerWithHeaders(log, cc) log.Infow("processing invite") @@ -345,7 +346,7 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI case DispatchNoRuleDrop: c.log.Debugw("Rejecting inbound flood") c.cc.Drop() - c.close(false, callDropped, "flood") + c.close(false, callFlood, "flood") return case DispatchNoRuleReject: c.log.Infow("Rejecting inbound call, doesn't match any Dispatch Rules") @@ -368,7 +369,7 @@ func (c *inboundCall) handleInvite(ctx context.Context, req *sip.Request, trunkI } acceptCall := func() bool { c.log.Infow("Accepting the call", "headers", disp.Headers) - if err = c.cc.Accept(ctx, netip.AddrPortFrom(c.s.signalingIp, uint16(c.s.conf.SIPPort)), answerData, disp.Headers); err != nil { + if err = c.cc.Accept(ctx, answerData, disp.Headers); err != nil { c.log.Errorw("Cannot respond to INVITE", err) return false } @@ -627,7 +628,10 @@ func (c *inboundCall) close(error bool, status CallStatus, reason string) { } else { c.log.Infow("Closing inbound call", "reason", reason) } - defer c.log.Infow("Inbound call closed", "reason", reason) + if status != callFlood { + defer c.log.Infow("Inbound call closed", "reason", reason) + } + c.closeMedia() c.cc.Close() if c.callDur != nil { @@ -815,12 +819,15 @@ func (c *inboundCall) transferCall(ctx context.Context, transferTo string, dialt } -func (s *Server) newInbound(id LocalTag, invite *sip.Request, inviteTx sip.ServerTransaction) *sipInbound { +func (s *Server) newInbound(id LocalTag, contact URI, invite *sip.Request, inviteTx sip.ServerTransaction) *sipInbound { c := &sipInbound{ - s: s, - id: id, - invite: invite, - inviteTx: inviteTx, + s: s, + id: id, + invite: invite, + inviteTx: inviteTx, + contact: &sip.ContactHeader{ + Address: *contact.GetContactURI(), + }, cancelled: make(chan struct{}), referDone: make(chan error), // Do not buffer the channel to avoid reading a result for an old request } @@ -843,6 +850,7 @@ type sipInbound struct { tag RemoteTag invite *sip.Request inviteTx sip.ServerTransaction + contact *sip.ContactHeader cancelled chan struct{} from *sip.FromHeader to *sip.ToHeader @@ -1009,7 +1017,7 @@ func (c *sipInbound) setDestFromVia(r *sip.Response) { } } -func (c *sipInbound) Accept(ctx context.Context, contactHost netip.AddrPort, sdpData []byte, headers map[string]string) error { +func (c *sipInbound) Accept(ctx context.Context, sdpData []byte, headers map[string]string) error { ctx, span := tracer.Start(ctx, "sipInbound.Accept") defer span.End() c.mu.Lock() @@ -1020,7 +1028,7 @@ func (c *sipInbound) Accept(ctx context.Context, contactHost netip.AddrPort, sdp r := sip.NewResponseFromRequest(c.invite, 200, "OK", sdpData) // This will effectively redirect future SIP requests to this server instance (if host address is not LB). - r.AppendHeader(&sip.ContactHeader{Address: sip.Uri{Host: contactHost.Addr().String(), Port: int(contactHost.Port())}}) + r.AppendHeader(c.contact) c.setDestFromVia(r) @@ -1125,9 +1133,7 @@ func (c *sipInbound) newReferReq(transferTo string) (*sip.Request, error) { } // This will effectively redirect future SIP requests to this server instance (if host address is not LB). - contactHeader := &sip.ContactHeader{Address: sip.Uri{Host: c.s.signalingIp.String(), Port: c.s.conf.SIPPort}} - - req := NewReferRequest(c.invite, c.inviteOk, contactHeader, transferTo) + req := NewReferRequest(c.invite, c.inviteOk, c.contact, transferTo) c.setCSeq(req) c.swapSrcDst(req) diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go index 68b8662..b65bd92 100644 --- a/pkg/sip/outbound.go +++ b/pkg/sip/outbound.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "math" - "net/netip" "sort" "sync" "time" @@ -75,23 +74,26 @@ type outboundCall struct { } func (c *Client) newCall(ctx context.Context, conf *config.Config, log logger.Logger, id LocalTag, room RoomConfig, sipConf sipOutboundConfig) (*outboundCall, error) { - if sipConf.host == "" { - sipConf.host = c.signalingIp.String() - } if sipConf.maxCallDuration <= 0 || sipConf.maxCallDuration > maxCallDuration { sipConf.maxCallDuration = maxCallDuration } if sipConf.ringingTimeout <= 0 { sipConf.ringingTimeout = defaultRingingTimeout } + tr := TransportFrom(sipConf.transport) + contact := c.ContactURI(tr) + if sipConf.host == "" { + sipConf.host = contact.GetHost() + } call := &outboundCall{ c: c, log: log, cc: c.newOutbound(id, URI{ - User: sipConf.from, - Host: sipConf.host, - Addr: netip.AddrPortFrom(c.signalingIp, uint16(conf.SIPPort)), - }), + User: sipConf.from, + Host: sipConf.host, + Addr: contact.Addr, + Transport: tr, + }, contact), sipConf: sipConf, } call.mon = c.mon.NewCall(stats.Outbound, sipConf.host, sipConf.address) @@ -375,9 +377,10 @@ func (c *outboundCall) sipSignal(ctx context.Context) error { joinDur := c.mon.JoinDur() c.mon.InviteReq() - sdpResp, err := c.cc.Invite(ctx, c.sipConf.transport, URI{ - User: c.sipConf.to, - Host: c.sipConf.address, + sdpResp, err := c.cc.Invite(ctx, URI{ + User: c.sipConf.to, + Host: c.sipConf.address, + Transport: TransportFrom(c.sipConf.transport), }, c.sipConf.user, c.sipConf.pass, c.sipConf.headers, sdpOffer) if err != nil { // TODO: should we retry? maybe new offer will work @@ -473,7 +476,7 @@ func (c *outboundCall) transferCall(ctx context.Context, transferTo string, dial return nil } -func (c *Client) newOutbound(id LocalTag, from URI) *sipOutbound { +func (c *Client) newOutbound(id LocalTag, from, contact URI) *sipOutbound { from = from.Normalize() fromHeader := &sip.FromHeader{ DisplayName: from.User, @@ -481,7 +484,7 @@ func (c *Client) newOutbound(id LocalTag, from URI) *sipOutbound { Params: sip.NewParams(), } contactHeader := &sip.ContactHeader{ - Address: *from.GetContactURI(), + Address: *contact.GetContactURI(), } fromHeader.Params.Add("tag", string(id)) return &sipOutbound{ @@ -542,20 +545,13 @@ func (c *sipOutbound) RemoteHeaders() Headers { return c.inviteOk.Headers() } -func (c *sipOutbound) Invite(ctx context.Context, transport livekit.SIPTransport, to URI, user, pass string, headers map[string]string, sdpOffer []byte) ([]byte, error) { +func (c *sipOutbound) Invite(ctx context.Context, to URI, user, pass string, headers map[string]string, sdpOffer []byte) ([]byte, error) { ctx, span := tracer.Start(ctx, "sipOutbound.Invite") defer span.End() c.mu.Lock() defer c.mu.Unlock() to = to.Normalize() toHeader := &sip.ToHeader{Address: *to.GetURI()} - toHeader.Address.UriParams = make(sip.HeaderParams) - switch transport { - case livekit.SIPTransport_SIP_TRANSPORT_UDP: - toHeader.Address.UriParams.Add("transport", "udp") - case livekit.SIPTransport_SIP_TRANSPORT_TCP: - toHeader.Address.UriParams.Add("transport", "tcp") - } dest := to.GetDest() diff --git a/pkg/sip/participant.go b/pkg/sip/participant.go index f94c3c7..6350762 100644 --- a/pkg/sip/participant.go +++ b/pkg/sip/participant.go @@ -75,6 +75,7 @@ func (v CallStatus) DisconnectReason() livekit.DisconnectReason { const ( callDropped = CallStatus(iota) + callFlood CallDialing CallAutomation CallActive diff --git a/pkg/sip/protocol.go b/pkg/sip/protocol.go index 1b78944..0c18c16 100644 --- a/pkg/sip/protocol.go +++ b/pkg/sip/protocol.go @@ -17,6 +17,7 @@ package sip import ( "context" "fmt" + "net/netip" "regexp" "strconv" "strings" @@ -25,6 +26,8 @@ import ( "github.com/emiago/sipgo/sip" "github.com/livekit/psrpc" "github.com/pkg/errors" + + "github.com/livekit/sip/pkg/config" ) const ( @@ -61,6 +64,35 @@ type Signaling interface { Drop() } +func transportFromReq(req *sip.Request) Transport { + if to, _ := req.To(); to != nil { + if tr, _ := to.Params.Get("transport"); tr != "" { + return Transport(strings.ToLower(tr)) + } + } + if via, _ := req.Via(); via != nil { + return Transport(strings.ToLower(via.Transport)) + } + return "" +} + +func transportPort(c *config.Config, t Transport) int { + if t == TransportTLS { + if tc := c.TLS; tc != nil { + return tc.Port + } + } + return c.SIPPort +} + +func getContactURI(c *config.Config, ip netip.Addr, t Transport) URI { + return URI{ + Host: c.SIPHostname, + Addr: netip.AddrPortFrom(ip, uint16(transportPort(c, t))), + Transport: t, + } +} + func sendAndACK(ctx context.Context, c Signaling, req *sip.Request) { tx, err := c.Transaction(req) if err != nil { diff --git a/pkg/sip/server.go b/pkg/sip/server.go index c0a5968..e1ce040 100644 --- a/pkg/sip/server.go +++ b/pkg/sip/server.go @@ -16,8 +16,10 @@ package sip import ( "context" + "crypto/tls" "errors" "fmt" + "io" "log/slog" "net" "net/netip" @@ -111,8 +113,7 @@ type Server struct { log logger.Logger mon *stats.Monitor sipSrv *sipgo.Server - sipConnUDP *net.UDPConn - sipConnTCP *net.TCPListener + sipListeners []io.Closer sipUnhandled RequestHandler signalingIp netip.Addr signalingIpLocal netip.Addr @@ -154,6 +155,10 @@ func (s *Server) SetHandler(handler Handler) { s.handler = handler } +func (s *Server) ContactURI(tr Transport) URI { + return getContactURI(s.conf, s.signalingIp, tr) +} + func (s *Server) startUDP(addr netip.AddrPort) error { lis, err := net.ListenUDP("udp", &net.UDPAddr{ IP: addr.Addr().AsSlice(), @@ -162,10 +167,10 @@ func (s *Server) startUDP(addr netip.AddrPort) error { if err != nil { return fmt.Errorf("cannot listen on the UDP signaling port %d: %w", s.conf.SIPPortListen, err) } - s.sipConnUDP = lis + s.sipListeners = append(s.sipListeners, lis) s.log.Infow("sip signaling listening on", "local", s.signalingIpLocal, "external", s.signalingIp, - "port", s.conf.SIPPortListen, "announce-port", s.conf.SIPPort, + "port", addr.Port(), "announce-port", s.conf.SIPPort, "proto", "udp", ) @@ -185,10 +190,10 @@ func (s *Server) startTCP(addr netip.AddrPort) error { if err != nil { return fmt.Errorf("cannot listen on the TCP signaling port %d: %w", s.conf.SIPPortListen, err) } - s.sipConnTCP = lis + s.sipListeners = append(s.sipListeners, lis) s.log.Infow("sip signaling listening on", "local", s.signalingIpLocal, "external", s.signalingIp, - "port", s.conf.SIPPortListen, "announce-port", s.conf.SIPPort, + "port", addr.Port(), "announce-port", s.conf.SIPPort, "proto", "tcp", ) @@ -200,6 +205,30 @@ func (s *Server) startTCP(addr netip.AddrPort) error { return nil } +func (s *Server) startTLS(addr netip.AddrPort, conf *tls.Config) error { + tlis, err := net.ListenTCP("tcp", &net.TCPAddr{ + IP: addr.Addr().AsSlice(), + Port: int(addr.Port()), + }) + if err != nil { + return fmt.Errorf("cannot listen on the TLS signaling port %d: %w", s.conf.SIPPortListen, err) + } + lis := tls.NewListener(tlis, conf) + s.sipListeners = append(s.sipListeners, lis) + s.log.Infow("sip signaling listening on", + "local", s.signalingIpLocal, "external", s.signalingIp, + "port", addr.Port(), "announce-port", s.conf.TLS.Port, + "proto", "tls", + ) + + go func() { + if err := s.sipSrv.ServeTLS(lis); err != nil && !errors.Is(err, net.ErrClosed) { + panic(fmt.Errorf("SIP listen TLS error: %w", err)) + } + }() + return nil +} + type RequestHandler func(req *sip.Request, tx sip.ServerTransaction) bool func (s *Server) Start(agent *sipgo.UserAgent, unhandled RequestHandler) error { @@ -265,6 +294,27 @@ func (s *Server) Start(agent *sipgo.UserAgent, unhandled RequestHandler) error { if err := s.startTCP(addr); err != nil { return err } + if tconf := s.conf.TLS; tconf != nil { + if len(tconf.Certs) == 0 { + return errors.New("TLS certificate required") + } + var certs []tls.Certificate + for _, c := range tconf.Certs { + cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) + if err != nil { + return err + } + certs = append(certs, cert) + } + tlsConf := &tls.Config{ + NextProtos: []string{"sip"}, + Certificates: certs, + } + addrTLS := netip.AddrPortFrom(ip, uint16(tconf.ListenPort)) + if err := s.startTLS(addrTLS, tlsConf); err != nil { + return err + } + } return nil } @@ -281,11 +331,8 @@ func (s *Server) Stop() { if s.sipSrv != nil { _ = s.sipSrv.Close() } - if s.sipConnUDP != nil { - _ = s.sipConnUDP.Close() - } - if s.sipConnTCP != nil { - _ = s.sipConnTCP.Close() + for _, l := range s.sipListeners { + _ = l.Close() } } diff --git a/pkg/sip/types.go b/pkg/sip/types.go index bc4a8f5..fffa5f1 100644 --- a/pkg/sip/types.go +++ b/pkg/sip/types.go @@ -21,6 +21,7 @@ import ( "strconv" "github.com/emiago/sipgo/sip" + "github.com/livekit/protocol/livekit" "github.com/livekit/protocol/logger" ) @@ -36,10 +37,31 @@ func (h Headers) GetHeader(name string) sip.Header { return nil } +func TransportFrom(t livekit.SIPTransport) Transport { + switch t { + case livekit.SIPTransport_SIP_TRANSPORT_UDP: + return TransportUDP + case livekit.SIPTransport_SIP_TRANSPORT_TCP: + return TransportTCP + case livekit.SIPTransport_SIP_TRANSPORT_TLS: + return TransportTLS + } + return "" +} + +type Transport string + +const ( + TransportUDP = Transport("udp") + TransportTCP = Transport("tcp") + TransportTLS = Transport("tls") +) + type URI struct { - User string - Host string - Addr netip.AddrPort + User string + Host string + Addr netip.AddrPort + Transport Transport } func (u URI) Normalize() URI { @@ -97,16 +119,23 @@ func (u URI) GetURI() *sip.Uri { if port := u.Addr.Port(); port != 0 { su.Port = int(port) } + if u.Transport != "" { + if su.UriParams == nil { + su.UriParams = make(sip.HeaderParams) + } + su.UriParams.Add("transport", string(u.Transport)) + } return su } func (u URI) GetContactURI() *sip.Uri { - su := &sip.Uri{ - User: u.User, - Host: u.Addr.Addr().String(), - } - if port := u.Addr.Port(); port != 0 { - su.Port = int(port) + su := u.GetURI() + switch u.Transport { + case TransportUDP, TransportTCP: + // Use IP instead of a hostname for TCP and UDP. + if addr := u.Addr.Addr(); addr.IsValid() { + su.Host = addr.String() + } } return su }