From 613c8fbc7309718eb0bc4177d1f1f795f9fde11f Mon Sep 17 00:00:00 2001 From: Denys Smirnov Date: Mon, 25 Nov 2024 12:53:31 +0200 Subject: [PATCH] Initial support for SRTP. --- pkg/media/dtmf/dtmf.go | 6 +- pkg/media/rtp/conn.go | 22 ++-- pkg/media/rtp/jitter.go | 11 +- pkg/media/rtp/listen.go | 7 +- pkg/media/rtp/mux.go | 18 ++-- pkg/media/rtp/rtp.go | 49 ++++----- pkg/media/rtp/session.go | 207 ++++++++++++++++++++++++++++++++++++ pkg/media/sdp/offer.go | 163 +++++++++++++++++++++++++--- pkg/media/sdp/offer_test.go | 48 ++++++++- pkg/media/srtp/srtp.go | 156 +++++++++++++++++++++++++++ pkg/sip/media.go | 38 +++---- pkg/sip/media_port.go | 191 +++++++++++++++++++++++++++------ pkg/sip/media_port_test.go | 114 +++++++++++++------- pkg/sip/outbound.go | 5 +- pkg/sip/service_test.go | 3 +- pkg/siptest/client.go | 20 ++-- 16 files changed, 889 insertions(+), 169 deletions(-) create mode 100644 pkg/media/rtp/session.go create mode 100644 pkg/media/srtp/srtp.go diff --git a/pkg/media/dtmf/dtmf.go b/pkg/media/dtmf/dtmf.go index e89998ab..8deebcc3 100644 --- a/pkg/media/dtmf/dtmf.go +++ b/pkg/media/dtmf/dtmf.go @@ -161,11 +161,11 @@ func Decode(data []byte) (Event, error) { }, nil } -func DecodeRTP(p *rtp.Packet) (Event, bool) { - if !p.Marker { +func DecodeRTP(h *rtp.Header, payload []byte) (Event, bool) { + if !h.Marker { return Event{}, false } - ev, err := Decode(p.Payload) + ev, err := Decode(payload) if err != nil { return Event{}, false } diff --git a/pkg/media/rtp/conn.go b/pkg/media/rtp/conn.go index 626c67cb..6463214c 100644 --- a/pkg/media/rtp/conn.go +++ b/pkg/media/rtp/conn.go @@ -16,6 +16,7 @@ package rtp import ( "net" + "net/netip" "sync" "sync/atomic" "time" @@ -131,9 +132,11 @@ func (c *Conn) Listen(portMin, portMax int, listenAddr string) error { if listenAddr == "" { listenAddr = "0.0.0.0" } - - var err error - c.conn, err = ListenUDPPortRange(portMin, portMax, net.ParseIP(listenAddr)) + ip, err := netip.ParseAddr(listenAddr) + if err != nil { + return err + } + c.conn, err = ListenUDPPortRange(portMin, portMax, ip) if err != nil { return err } @@ -167,24 +170,23 @@ func (c *Conn) readLoop() { close(c.received) } if h := c.onRTP.Load(); h != nil { - _ = (*h).HandleRTP(&p) + _ = (*h).HandleRTP(&p.Header, p.Payload) } } } -func (c *Conn) WriteRTP(p *rtp.Packet) error { +func (c *Conn) WriteRTP(h *rtp.Header, payload []byte) (int, error) { addr := c.dest.Load() if addr == nil { - return nil + return 0, nil } - data, err := p.Marshal() + data, err := (&rtp.Packet{Header: *h, Payload: payload}).Marshal() if err != nil { - return err + return 0, err } c.wmu.Lock() defer c.wmu.Unlock() - _, err = c.conn.WriteToUDP(data, addr) - return err + return c.conn.WriteToUDP(data, addr) } func (c *Conn) ReadRTP() (*rtp.Packet, *net.UDPAddr, error) { diff --git a/pkg/media/rtp/jitter.go b/pkg/media/rtp/jitter.go index 9a1f3874..929c56b0 100644 --- a/pkg/media/rtp/jitter.go +++ b/pkg/media/rtp/jitter.go @@ -17,8 +17,9 @@ package rtp import ( "time" - "github.com/livekit/server-sdk-go/v2/pkg/jitter" "github.com/pion/rtp" + + "github.com/livekit/server-sdk-go/v2/pkg/jitter" ) const ( @@ -41,11 +42,11 @@ type jitterHandler struct { buf *jitter.Buffer } -func (h *jitterHandler) HandleRTP(p *rtp.Packet) error { - h.buf.Push(p) +func (r *jitterHandler) HandleRTP(h *rtp.Header, payload []byte) error { + r.buf.Push(&rtp.Packet{Header: *h, Payload: payload}) var last error - for _, p := range h.buf.Pop(false) { - if err := h.h.HandleRTP(p); err != nil { + for _, p := range r.buf.Pop(false) { + if err := r.h.HandleRTP(&p.Header, p.Payload); err != nil { last = err } } diff --git a/pkg/media/rtp/listen.go b/pkg/media/rtp/listen.go index 3d9fb7c2..72c26652 100644 --- a/pkg/media/rtp/listen.go +++ b/pkg/media/rtp/listen.go @@ -18,14 +18,15 @@ import ( "errors" "math/rand" "net" + "net/netip" ) var ListenErr = errors.New("failed to listen on udp port") -func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) { +func ListenUDPPortRange(portMin, portMax int, ip netip.Addr) (*net.UDPConn, error) { if portMin == 0 && portMax == 0 { return net.ListenUDP("udp", &net.UDPAddr{ - IP: IP, + IP: ip.AsSlice(), Port: 0, }) } @@ -48,7 +49,7 @@ func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) { portCurrent := portStart for { - c, e := net.ListenUDP("udp", &net.UDPAddr{IP: IP, Port: portCurrent}) + c, e := net.ListenUDP("udp", &net.UDPAddr{IP: ip.AsSlice(), Port: portCurrent}) if e == nil { return c, nil } diff --git a/pkg/media/rtp/mux.go b/pkg/media/rtp/mux.go index d115fdee..52d20226 100644 --- a/pkg/media/rtp/mux.go +++ b/pkg/media/rtp/mux.go @@ -35,25 +35,25 @@ type Mux struct { // HandleRTP selects a Handler based on payload type. // Types can be registered with Register. If no handler is set, a default one will be used. -func (m *Mux) HandleRTP(p *rtp.Packet) error { +func (m *Mux) HandleRTP(h *rtp.Header, payload []byte) error { if m == nil { return nil } - var h Handler + var r Handler m.mu.RLock() - if p.PayloadType < byte(len(m.static)) { - h = m.static[p.PayloadType] + if h.PayloadType < byte(len(m.static)) { + r = m.static[h.PayloadType] } else { - h = m.dynamic[p.PayloadType] + r = m.dynamic[h.PayloadType] } - if h == nil { - h = m.def + if r == nil { + r = m.def } m.mu.RUnlock() - if h == nil { + if r == nil { return nil } - return h.HandleRTP(p) + return r.HandleRTP(h, payload) } // SetDefault sets a default RTP handler. diff --git a/pkg/media/rtp/rtp.go b/pkg/media/rtp/rtp.go index 08ccf7c5..de1034de 100644 --- a/pkg/media/rtp/rtp.go +++ b/pkg/media/rtp/rtp.go @@ -17,6 +17,7 @@ package rtp import ( "fmt" "math/rand/v2" + "slices" "sync" "github.com/pion/interceptor" @@ -31,7 +32,7 @@ type BytesFrame interface { } type Writer interface { - WriteRTP(p *rtp.Packet) error + WriteRTP(h *rtp.Header, payload []byte) (int, error) } type Reader interface { @@ -39,13 +40,13 @@ type Reader interface { } type Handler interface { - HandleRTP(p *rtp.Packet) error + HandleRTP(h *rtp.Header, payload []byte) error } -type HandlerFunc func(p *rtp.Packet) error +type HandlerFunc func(h *rtp.Header, payload []byte) error -func (fnc HandlerFunc) HandleRTP(p *rtp.Packet) error { - return fnc(p) +func (fnc HandlerFunc) HandleRTP(h *rtp.Header, payload []byte) error { + return fnc(h, payload) } func HandleLoop(r Reader, h Handler) error { @@ -54,7 +55,7 @@ func HandleLoop(r Reader, h Handler) error { if err != nil { return err } - err = h.HandleRTP(p) + err = h.HandleRTP(&p.Header, p.Payload) if err != nil { return err } @@ -64,26 +65,27 @@ func HandleLoop(r Reader, h Handler) error { // Buffer is a Writer that clones and appends RTP packets into a slice. type Buffer []*Packet -func (b *Buffer) WriteRTP(p *Packet) error { - p2 := p.Clone() - *b = append(*b, p2) +func (b *Buffer) WriteRTP(h *rtp.Header, payload []byte) error { + *b = append(*b, &rtp.Packet{ + Header: *h, + Payload: slices.Clone(payload), + }) return nil } // NewSeqWriter creates an RTP writer that automatically increments the sequence number. func NewSeqWriter(w Writer) *SeqWriter { s := &SeqWriter{w: w} - s.p = rtp.Packet{ - Header: rtp.Header{ - Version: 2, - SSRC: rand.Uint32(), - SequenceNumber: 0, - }, + s.h = rtp.Header{ + Version: 2, + SSRC: rand.Uint32(), + SequenceNumber: 0, } return s } type Packet = rtp.Packet +type Header = rtp.Header type Event struct { Type byte @@ -95,20 +97,19 @@ type Event struct { type SeqWriter struct { mu sync.Mutex w Writer - p Packet + h Header } func (s *SeqWriter) WriteEvent(ev *Event) error { s.mu.Lock() defer s.mu.Unlock() - s.p.PayloadType = ev.Type - s.p.Payload = ev.Payload - s.p.Marker = ev.Marker - s.p.Timestamp = ev.Timestamp - if err := s.w.WriteRTP(&s.p); err != nil { + s.h.PayloadType = ev.Type + s.h.Marker = ev.Marker + s.h.Timestamp = ev.Timestamp + if _, err := s.w.WriteRTP(&s.h, ev.Payload); err != nil { return err } - s.p.Header.SequenceNumber++ + s.h.SequenceNumber++ return nil } @@ -211,6 +212,6 @@ func (s *MediaStreamIn[T]) String() string { return fmt.Sprintf("RTP(%d) -> %s", s.Writer.SampleRate(), s.Writer) } -func (s *MediaStreamIn[T]) HandleRTP(p *rtp.Packet) error { - return s.Writer.WriteSample(T(p.Payload)) +func (s *MediaStreamIn[T]) HandleRTP(_ *rtp.Header, payload []byte) error { + return s.Writer.WriteSample(T(payload)) } diff --git a/pkg/media/rtp/session.go b/pkg/media/rtp/session.go new file mode 100644 index 00000000..727938e0 --- /dev/null +++ b/pkg/media/rtp/session.go @@ -0,0 +1,207 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rtp + +import ( + "io" + "net" + "slices" + "sync" + + "github.com/frostbyte73/core" + "github.com/pion/rtp" +) + +const enableZeroCopy = true + +type Session interface { + OpenWriteStream() (WriteStream, error) + AcceptStream() (ReadStream, uint32, error) + Close() error +} + +type WriteStream interface { + // WriteRTP writes RTP packet to the connection. + WriteRTP(h *rtp.Header, payload []byte) (int, error) +} + +type ReadStream interface { + // ReadRTP reads RTP packet and its header from the connection. + ReadRTP(h *rtp.Header, payload []byte) (int, error) +} + +func NewSession(conn net.Conn) Session { + return &session{ + conn: conn, + w: &writeStream{conn: conn}, + bySSRC: make(map[uint32]*readStream), + } +} + +type session struct { + conn net.Conn + closed core.Fuse + w *writeStream + + rmu sync.Mutex + rbuf [1500]byte + bySSRC map[uint32]*readStream +} + +func (s *session) OpenWriteStream() (WriteStream, error) { + return s.w, nil +} + +func (s *session) AcceptStream() (ReadStream, uint32, error) { + s.rmu.Lock() + defer s.rmu.Unlock() + for { + n, err := s.conn.Read(s.rbuf[:]) + if err != nil { + return nil, 0, err + } + buf := s.rbuf[:n] + var p rtp.Packet + err = p.Unmarshal(buf) + if err != nil { + continue // ignore + } + + isNew := false + r := s.bySSRC[p.SSRC] + if r == nil { + r = &readStream{ + ssrc: p.SSRC, + closed: s.closed.Watch(), + copied: make(chan int), + recv: make(chan *rtp.Packet, 10), + } + s.bySSRC[p.SSRC] = r + isNew = true + } + r.write(&p) + if isNew { + return r, r.ssrc, nil + } + } +} + +func (s *session) Close() error { + var err error + s.closed.Once(func() { + err = s.conn.Close() + s.rmu.Lock() + defer s.rmu.Unlock() + s.bySSRC = nil + }) + return err +} + +type writeStream struct { + mu sync.Mutex + buf []byte + conn net.Conn +} + +func (w *writeStream) WriteRTP(h *rtp.Header, payload []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + hsz := h.MarshalSize() + sz := hsz + len(payload) + w.buf = w.buf[:0] + w.buf = slices.Grow(w.buf, sz) + buf := w.buf[:sz] + n, err := h.MarshalTo(buf) + if err != nil { + return 0, err + } + copy(buf[n:], payload) + return w.conn.Write(buf) +} + +type readStream struct { + ssrc uint32 + + closed <-chan struct{} + recv chan *rtp.Packet + copied chan int + mu sync.Mutex + hdr *rtp.Header + payload []byte +} + +func (r *readStream) write(p *rtp.Packet) { + if enableZeroCopy { + r.mu.Lock() + h, payload := r.hdr, r.payload + r.hdr, r.payload = nil, nil + r.mu.Unlock() + if h != nil { + // zero copy + *h = p.Header + n := copy(payload, p.Payload) + select { + case <-r.closed: + case r.copied <- n: + } + return + } + } + p.Payload = slices.Clone(p.Payload) + select { + case r.recv <- p: + default: + } +} + +func (r *readStream) ReadRTP(h *rtp.Header, payload []byte) (int, error) { + direct := false + if enableZeroCopy { + r.mu.Lock() + if r.hdr == nil { + r.hdr = h + r.payload = payload + direct = true + } + r.mu.Unlock() + } + if !direct { + select { + case p := <-r.recv: + *h = p.Header + n := copy(payload, p.Payload) + return n, nil + case <-r.closed: + } + return 0, io.EOF + } + defer func() { + r.mu.Lock() + defer r.mu.Unlock() + if r.hdr == h { + r.hdr, r.payload = nil, nil + } + }() + select { + case n := <-r.copied: + return n, nil + case p := <-r.recv: + *h = p.Header + n := copy(payload, p.Payload) + return n, nil + case <-r.closed: + } + return 0, io.EOF +} diff --git a/pkg/media/sdp/offer.go b/pkg/media/sdp/offer.go index 6fceb65a..a80117ab 100644 --- a/pkg/media/sdp/offer.go +++ b/pkg/media/sdp/offer.go @@ -15,6 +15,7 @@ package sdp import ( + "encoding/base64" "errors" "fmt" "math/rand/v2" @@ -29,6 +30,7 @@ import ( "github.com/livekit/sip/pkg/media" "github.com/livekit/sip/pkg/media/dtmf" "github.com/livekit/sip/pkg/media/rtp" + "github.com/livekit/sip/pkg/media/srtp" ) type CodecInfo struct { @@ -70,11 +72,27 @@ func OfferCodecs() []CodecInfo { } type MediaDesc struct { - Codecs []CodecInfo - DTMFType byte // set to 0 if there's no DTMF + Codecs []CodecInfo + DTMFType byte // set to 0 if there's no DTMF + CryptoProfiles []srtp.Profile } -func OfferMedia(rtpListenerPort int) (MediaDesc, *sdp.MediaDescription) { +func appendCryptoProfiles(attrs []sdp.Attribute, profiles []srtp.Profile) []sdp.Attribute { + var buf []byte + for _, p := range profiles { + buf = buf[:0] + buf = append(buf, p.Key...) + buf = append(buf, p.Salt...) + skey := base64.StdEncoding.WithPadding(base64.NoPadding).EncodeToString(buf) + attrs = append(attrs, sdp.Attribute{ + Key: "crypto", + Value: fmt.Sprintf("%d %s inline:%s", p.Index, p.Profile, skey), + }) + } + return attrs +} + +func OfferMedia(rtpListenerPort int, encrypted bool) (MediaDesc, *sdp.MediaDescription, error) { // Static compiler check for frame duration hardcoded below. var _ = [1]struct{}{}[20*time.Millisecond-rtp.DefFrameDur] @@ -98,26 +116,42 @@ func OfferMedia(rtpListenerPort int) (MediaDesc, *sdp.MediaDescription) { Key: "fmtp", Value: fmt.Sprintf("%d 0-16", dtmfType), }) } + var cryptoProfiles []srtp.Profile + if encrypted { + var err error + cryptoProfiles, err = srtp.DefaultProfiles() + if err != nil { + return MediaDesc{}, nil, err + } + attrs = appendCryptoProfiles(attrs, cryptoProfiles) + } + attrs = append(attrs, []sdp.Attribute{ {Key: "ptime", Value: "20"}, {Key: "sendrecv"}, }...) + proto := "AVP" + if encrypted { + proto = "SAVP" + } + return MediaDesc{ - Codecs: codecs, - DTMFType: dtmfType, + Codecs: codecs, + DTMFType: dtmfType, + CryptoProfiles: cryptoProfiles, }, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", Port: sdp.RangedPort{Value: rtpListenerPort}, - Protos: []string{"RTP", "AVP"}, + Protos: []string{"RTP", proto}, Formats: formats, }, Attributes: attrs, - } + }, nil } -func AnswerMedia(rtpListenerPort int, audio *AudioConfig) *sdp.MediaDescription { +func AnswerMedia(rtpListenerPort int, audio *AudioConfig, crypt *srtp.Profile) *sdp.MediaDescription { // Static compiler check for frame duration hardcoded below. var _ = [1]struct{}{}[20*time.Millisecond-rtp.DefFrameDur] @@ -134,6 +168,11 @@ func AnswerMedia(rtpListenerPort int, audio *AudioConfig) *sdp.MediaDescription {Key: "fmtp", Value: fmt.Sprintf("%d 0-16", audio.DTMFType)}, }...) } + proto := "AVP" + if crypt != nil { + proto = "SAVP" + attrs = appendCryptoProfiles(attrs, []srtp.Profile{*crypt}) + } attrs = append(attrs, []sdp.Attribute{ {Key: "ptime", Value: "20"}, {Key: "sendrecv"}, @@ -142,7 +181,7 @@ func AnswerMedia(rtpListenerPort int, audio *AudioConfig) *sdp.MediaDescription MediaName: sdp.MediaName{ Media: "audio", Port: sdp.RangedPort{Value: rtpListenerPort}, - Protos: []string{"RTP", "AVP"}, + Protos: []string{"RTP", proto}, Formats: formats, }, Attributes: attrs, @@ -159,10 +198,13 @@ type Offer Description type Answer Description -func NewOffer(publicIp netip.Addr, rtpListenerPort int) *Offer { +func NewOffer(publicIp netip.Addr, rtpListenerPort int, encrypted bool) (*Offer, error) { sessId := rand.Uint64() // TODO: do we need to track these? - m, mediaDesc := OfferMedia(rtpListenerPort) + m, mediaDesc, err := OfferMedia(rtpListenerPort, encrypted) + if err != nil { + return nil, err + } offer := sdp.SessionDescription{ Version: 0, Origin: sdp.Origin{ @@ -193,7 +235,7 @@ func NewOffer(publicIp netip.Addr, rtpListenerPort int) *Offer { SDP: offer, Addr: netip.AddrPortFrom(publicIp, uint16(rtpListenerPort)), MediaDesc: m, - } + }, nil } func (d *Offer) Answer(publicIp netip.Addr, rtpListenerPort int) (*Answer, *MediaConfig, error) { @@ -202,7 +244,22 @@ func (d *Offer) Answer(publicIp netip.Addr, rtpListenerPort int) (*Answer, *Medi return nil, nil, err } - mediaDesc := AnswerMedia(rtpListenerPort, audio) + var ( + sconf *srtp.Config + sprof *srtp.Profile + ) + if len(d.CryptoProfiles) != 0 { + answer, err := srtp.DefaultProfiles() + if err != nil { + return nil, nil, err + } + sconf, sprof, err = SelectCrypto(d.CryptoProfiles, answer, true) + if err != nil { + return nil, nil, err + } + } + + mediaDesc := AnswerMedia(rtpListenerPort, audio, sprof) answer := sdp.SessionDescription{ Version: 0, Origin: sdp.Origin{ @@ -243,6 +300,7 @@ func (d *Offer) Answer(publicIp netip.Addr, rtpListenerPort int) (*Answer, *Medi Local: src, Remote: d.Addr, Audio: *audio, + Crypto: sconf, }, nil } @@ -251,10 +309,18 @@ func (d *Answer) Apply(offer *Offer) (*MediaConfig, error) { if err != nil { return nil, err } + var sconf *srtp.Config + if len(d.CryptoProfiles) != 0 { + sconf, _, err = SelectCrypto(offer.CryptoProfiles, d.CryptoProfiles, false) + if err != nil { + return nil, err + } + } return &MediaConfig{ Local: offer.Addr, Remote: d.Addr, Audio: *audio, + Crypto: sconf, }, nil } @@ -315,6 +381,39 @@ func ParseMedia(d *sdp.MediaDescription) (*MediaDesc, error) { Type: byte(typ), Codec: codec, }) + case "crypto": + sub := strings.SplitN(m.Value, " ", 3) + if len(sub) != 3 { + continue + } + sind, prof, skey := sub[0], srtp.ProtectionProfile(sub[1]), sub[2] + ind, err := strconv.Atoi(sind) + if err != nil { + return nil, err + } + var ok bool + skey, ok = strings.CutPrefix(skey, "inline:") + if !ok { + continue + } + keys, err := base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(skey) + if err != nil { + return nil, err + } + var salt []byte + if sp, err := prof.Parse(); err == nil { + keyLen, err := sp.KeyLen() + if err != nil { + return nil, err + } + keys, salt = keys[:keyLen], keys[keyLen:] + } + out.CryptoProfiles = append(out.CryptoProfiles, srtp.Profile{ + Index: ind, + Profile: prof, + Key: keys, + Salt: salt, + }) } } for _, f := range d.MediaName.Formats { @@ -335,6 +434,7 @@ type MediaConfig struct { Local netip.AddrPort Remote netip.AddrPort Audio AudioConfig + Crypto *srtp.Config } type AudioConfig struct { @@ -369,3 +469,40 @@ func SelectAudio(desc MediaDesc) (*AudioConfig, error) { DTMFType: desc.DTMFType, }, nil } + +func SelectCrypto(offer, answer []srtp.Profile, swap bool) (*srtp.Config, *srtp.Profile, error) { + if len(offer) == 0 { + return nil, nil, nil + } + for _, ans := range answer { + sp, err := ans.Profile.Parse() + if err != nil { + continue + } + i := slices.IndexFunc(offer, func(off srtp.Profile) bool { + return off.Profile == ans.Profile + }) + if i >= 0 { + off := offer[i] + c := &srtp.Config{ + Keys: srtp.SessionKeys{ + LocalMasterKey: off.Key, + LocalMasterSalt: off.Salt, + RemoteMasterKey: ans.Key, + RemoteMasterSalt: ans.Salt, + }, + Profile: sp, + } + if swap { + c.Keys.LocalMasterKey, c.Keys.RemoteMasterKey = c.Keys.RemoteMasterKey, c.Keys.LocalMasterKey + c.Keys.LocalMasterSalt, c.Keys.RemoteMasterSalt = c.Keys.RemoteMasterSalt, c.Keys.LocalMasterSalt + } + prof := &off + if swap { + prof = &ans + } + return c, prof, nil + } + } + return nil, nil, errors.New("no common crypto") +} diff --git a/pkg/media/sdp/offer_test.go b/pkg/media/sdp/offer_test.go index 390537c3..cd6cfa89 100644 --- a/pkg/media/sdp/offer_test.go +++ b/pkg/media/sdp/offer_test.go @@ -15,6 +15,8 @@ package sdp_test import ( + "slices" + "strings" "testing" "github.com/pion/sdp/v3" @@ -27,9 +29,19 @@ import ( . "github.com/livekit/sip/pkg/media/sdp" ) +func getInline(s string) string { + const word = "inline:" + i := strings.Index(s, word) + if i < 0 { + return s + } + return s[i+len(word):] +} + func TestSDPMediaOffer(t *testing.T) { const port = 12345 - _, offer := OfferMedia(port) + _, offer, err := OfferMedia(port, false) + require.NoError(t, err) require.Equal(t, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", @@ -48,10 +60,39 @@ func TestSDPMediaOffer(t *testing.T) { }, }, offer) + _, offer, err = OfferMedia(port, true) + require.NoError(t, err) + i := slices.IndexFunc(offer.Attributes, func(a sdp.Attribute) bool { + return a.Key == "crypto" + }) + require.True(t, i > 0) + require.Equal(t, &sdp.MediaDescription{ + MediaName: sdp.MediaName{ + Media: "audio", + Port: sdp.RangedPort{Value: port}, + Protos: []string{"RTP", "SAVP"}, + Formats: []string{"9", "0", "8", "101"}, + }, + Attributes: []sdp.Attribute{ + {Key: "rtpmap", Value: "9 G722/8000"}, + {Key: "rtpmap", Value: "0 PCMU/8000"}, + {Key: "rtpmap", Value: "8 PCMA/8000"}, + {Key: "rtpmap", Value: "101 telephone-event/8000"}, + {Key: "fmtp", Value: "101 0-16"}, + {Key: "crypto", Value: "1 AES_CM_128_HMAC_SHA1_80 inline:" + getInline(offer.Attributes[i+0].Value)}, + {Key: "crypto", Value: "2 AES_CM_128_HMAC_SHA1_32 inline:" + getInline(offer.Attributes[i+1].Value)}, + {Key: "crypto", Value: "3 AES_256_CM_HMAC_SHA1_80 inline:" + getInline(offer.Attributes[i+2].Value)}, + {Key: "crypto", Value: "4 AES_256_CM_HMAC_SHA1_32 inline:" + getInline(offer.Attributes[i+3].Value)}, + {Key: "ptime", Value: "20"}, + {Key: "sendrecv"}, + }, + }, offer) + media.CodecSetEnabled(g722.SDPName, false) defer media.CodecSetEnabled(g722.SDPName, true) - _, offer = OfferMedia(port) + _, offer, err = OfferMedia(port, false) + require.NoError(t, err) require.Equal(t, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", @@ -230,7 +271,8 @@ func TestSDPMediaAnswer(t *testing.T) { require.Equal(t, c.exp, got) }) } - _, offer := OfferMedia(port) + _, offer, err := OfferMedia(port, false) + require.NoError(t, err) require.Equal(t, &sdp.MediaDescription{ MediaName: sdp.MediaName{ Media: "audio", diff --git a/pkg/media/srtp/srtp.go b/pkg/media/srtp/srtp.go new file mode 100644 index 00000000..0d407d63 --- /dev/null +++ b/pkg/media/srtp/srtp.go @@ -0,0 +1,156 @@ +// Copyright 2024 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package srtp + +import ( + "crypto/rand" + "fmt" + "net" + + prtp "github.com/pion/rtp" + "github.com/pion/srtp/v2" + + "github.com/livekit/sip/pkg/media/rtp" +) + +var defaultProfiles = []ProtectionProfile{ + "AES_CM_128_HMAC_SHA1_80", + "AES_CM_128_HMAC_SHA1_32", + "AES_256_CM_HMAC_SHA1_80", + "AES_256_CM_HMAC_SHA1_32", +} + +func DefaultProfiles() ([]Profile, error) { + out := make([]Profile, 0, len(defaultProfiles)) + for i, p := range defaultProfiles { + sp, err := p.Parse() + if err != nil { + return nil, err + } + keyLen, err := sp.KeyLen() + if err != nil { + return nil, err + } + saltLen, err := sp.SaltLen() + if err != nil { + return nil, err + } + key := make([]byte, keyLen) + salt := make([]byte, saltLen) + if _, err := rand.Read(key); err != nil { + return nil, err + } + if _, err := rand.Read(salt); err != nil { + return nil, err + } + out = append(out, Profile{ + Index: i + 1, + Profile: p, + Key: key, + Salt: salt, + }) + } + return out, nil +} + +type Options struct { + Profiles []Profile +} + +type ProtectionProfile string + +func (p ProtectionProfile) Parse() (srtp.ProtectionProfile, error) { + switch p { + case "AES_CM_128_HMAC_SHA1_80": + return srtp.ProtectionProfileAes128CmHmacSha1_80, nil + case "AES_CM_128_HMAC_SHA1_32": + return srtp.ProtectionProfileAes128CmHmacSha1_32, nil + case "AES_256_CM_HMAC_SHA1_80": + return srtp.ProtectionProfileAes256CmHmacSha1_80, nil + case "AES_256_CM_HMAC_SHA1_32": + return srtp.ProtectionProfileAes256CmHmacSha1_32, nil + default: + return 0, fmt.Errorf("unsupported profile %q", p) + } +} + +type Profile struct { + Index int + Profile ProtectionProfile + Key []byte + Salt []byte +} + +type Config = srtp.Config +type SessionKeys = srtp.SessionKeys + +func NewSession(conn net.Conn, conf *Config) (rtp.Session, error) { + s, err := srtp.NewSessionSRTP(conn, conf) + if err != nil { + return nil, err + } + return &session{s: s}, nil +} + +type session struct { + s *srtp.SessionSRTP +} + +func (s *session) OpenWriteStream() (rtp.WriteStream, error) { + w, err := s.s.OpenWriteStream() + if err != nil { + return nil, err + } + return writeStream{w: w}, nil +} + +func (s *session) AcceptStream() (rtp.ReadStream, uint32, error) { + r, ssrc, err := s.s.AcceptStream() + if err != nil { + return nil, 0, err + } + return readStream{r: r}, ssrc, nil +} + +func (s *session) Close() error { + return s.s.Close() +} + +type writeStream struct { + w *srtp.WriteStreamSRTP +} + +func (w writeStream) WriteRTP(h *prtp.Header, payload []byte) (int, error) { + return w.w.WriteRTP(h, payload) +} + +type readStream struct { + r *srtp.ReadStreamSRTP +} + +func (r readStream) ReadRTP(h *prtp.Header, payload []byte) (int, error) { + buf := payload + n, err := r.r.Read(buf) + if err != nil { + return 0, err + } + var p rtp.Packet + if err = p.Unmarshal(buf[:n]); err != nil { + return 0, err + } + *h = p.Header + n = copy(payload, p.Payload) + return n, nil +} diff --git a/pkg/sip/media.go b/pkg/sip/media.go index d4eeb5b8..9810e6e4 100644 --- a/pkg/sip/media.go +++ b/pkg/sip/media.go @@ -17,6 +17,8 @@ package sip import ( "strconv" + prtp "github.com/pion/rtp" + "github.com/livekit/sip/pkg/media/rtp" "github.com/livekit/sip/pkg/stats" ) @@ -26,13 +28,13 @@ const ( RoomSampleRate = 48000 ) -func newRTPStatsHandler(mon *stats.CallMonitor, typ string, h rtp.Handler) rtp.Handler { - if h == nil { - h = rtp.HandlerFunc(func(p *rtp.Packet) error { +func newRTPStatsHandler(mon *stats.CallMonitor, typ string, r rtp.Handler) rtp.Handler { + if r == nil { + r = rtp.HandlerFunc(func(h *rtp.Header, payload []byte) error { return nil }) } - return &rtpStatsHandler{h: h, typ: typ, mon: mon} + return &rtpStatsHandler{h: r, typ: typ, mon: mon} } type rtpStatsHandler struct { @@ -41,34 +43,34 @@ type rtpStatsHandler struct { mon *stats.CallMonitor } -func (h *rtpStatsHandler) HandleRTP(p *rtp.Packet) error { - if h.mon != nil { - typ := h.typ +func (r *rtpStatsHandler) HandleRTP(h *rtp.Header, payload []byte) error { + if r.mon != nil { + typ := r.typ if typ == "" { - typ = strconv.Itoa(int(p.PayloadType)) + typ = strconv.Itoa(int(h.PayloadType)) } - h.mon.RTPPacketRecv(typ) + r.mon.RTPPacketRecv(typ) } - return h.h.HandleRTP(p) + return r.h.HandleRTP(h, payload) } -func newRTPStatsWriter(mon *stats.CallMonitor, typ string, w rtp.Writer) rtp.Writer { +func newRTPStatsWriter(mon *stats.CallMonitor, typ string, w rtp.WriteStream) rtp.WriteStream { return &rtpStatsWriter{w: w, typ: typ, mon: mon} } type rtpStatsWriter struct { - w rtp.Writer + w rtp.WriteStream typ string mon *stats.CallMonitor } -func (h *rtpStatsWriter) WriteRTP(p *rtp.Packet) error { - if h.mon != nil { - typ := h.typ +func (w *rtpStatsWriter) WriteRTP(h *prtp.Header, payload []byte) (int, error) { + if w.mon != nil { + typ := w.typ if typ == "" { - typ = strconv.Itoa(int(p.PayloadType)) + typ = strconv.Itoa(int(h.PayloadType)) } - h.mon.RTPPacketSend(typ) + w.mon.RTPPacketSend(typ) } - return h.w.WriteRTP(p) + return w.w.WriteRTP(h, payload) } diff --git a/pkg/sip/media_port.go b/pkg/sip/media_port.go index 2e7f7c76..bd207abc 100644 --- a/pkg/sip/media_port.go +++ b/pkg/sip/media_port.go @@ -16,12 +16,16 @@ package sip import ( "context" + "errors" + "io" "net" "net/netip" "sync" "sync/atomic" "time" + "github.com/frostbyte73/core" + "github.com/livekit/mediatransportutil/pkg/rtcconfig" "github.com/livekit/protocol/logger" @@ -29,10 +33,56 @@ import ( "github.com/livekit/sip/pkg/media/dtmf" "github.com/livekit/sip/pkg/media/rtp" "github.com/livekit/sip/pkg/media/sdp" + "github.com/livekit/sip/pkg/media/srtp" "github.com/livekit/sip/pkg/mixer" "github.com/livekit/sip/pkg/stats" ) +type UDPConn interface { + net.Conn + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) + WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) +} + +func newUDPConn(conn UDPConn) *udpConn { + return &udpConn{UDPConn: conn} +} + +type udpConn struct { + UDPConn + src atomic.Pointer[netip.AddrPort] + dst atomic.Pointer[netip.AddrPort] +} + +func (c *udpConn) GetSrc() (netip.AddrPort, bool) { + ptr := c.src.Load() + if ptr == nil { + return netip.AddrPort{}, false + } + addr := *ptr + return addr, addr.IsValid() +} + +func (c *udpConn) SetDst(addr netip.AddrPort) { + if addr.IsValid() { + c.dst.Store(&addr) + } +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + n, addr, err := c.ReadFromUDPAddrPort(b) + c.src.Store(&addr) + return n, err +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + dst := c.dst.Load() + if dst == nil { + return len(b), nil // ignore + } + return c.WriteToUDPAddrPort(b, *dst) +} + type MediaConf struct { sdp.MediaConfig Processor media.PCM16Processor @@ -49,25 +99,23 @@ func NewMediaPort(log logger.Logger, mon *stats.CallMonitor, conf *MediaConfig, return NewMediaPortWith(log, mon, nil, conf, sampleRate) } -func NewMediaPortWith(log logger.Logger, mon *stats.CallMonitor, conn rtp.UDPConn, conf *MediaConfig, sampleRate int) (*MediaPort, error) { +func NewMediaPortWith(log logger.Logger, mon *stats.CallMonitor, conn UDPConn, conf *MediaConfig, sampleRate int) (*MediaPort, error) { + if conn == nil { + c, err := rtp.ListenUDPPortRange(conf.Ports.Start, conf.Ports.End, netip.AddrFrom4([4]byte{0, 0, 0, 0})) + if err != nil { + return nil, err + } + conn = c + } mediaTimeout := make(chan struct{}) p := &MediaPort{ log: log, mon: mon, externalIP: conf.IP, mediaTimeout: mediaTimeout, - conn: rtp.NewConnWith(conn, &rtp.ConnConfig{ - MediaTimeoutInitial: conf.MediaTimeoutInitial, - MediaTimeout: conf.MediaTimeout, - TimeoutCallback: func() { - close(mediaTimeout) - }, - }), - audioOut: media.NewSwitchWriter(sampleRate), - audioIn: media.NewSwitchWriter(sampleRate), - } - if err := p.conn.ListenAndServe(conf.Ports.Start, conf.Ports.End, "0.0.0.0"); err != nil { - return nil, err + port: newUDPConn(conn), + audioOut: media.NewSwitchWriter(sampleRate), + audioIn: media.NewSwitchWriter(sampleRate), } p.log.Debugw("listening for media on UDP", "port", p.Port()) return p, nil @@ -78,13 +126,16 @@ type MediaPort struct { log logger.Logger mon *stats.CallMonitor externalIP netip.Addr - conn *rtp.Conn + port *udpConn mediaTimeout <-chan struct{} + mediaReceived core.Fuse dtmfAudioEnabled bool closed atomic.Bool mu sync.Mutex conf *MediaConf + sess rtp.Session + hnd atomic.Pointer[rtp.Handler] dtmfOutRTP *rtp.Stream dtmfOutAudio media.PCM16Writer @@ -96,7 +147,7 @@ type MediaPort struct { } func (p *MediaPort) EnableTimeout(enabled bool) { - p.conn.EnableTimeout(enabled) + //p.conn.EnableTimeout(enabled) // FIXME } func (p *MediaPort) Close() { @@ -119,15 +170,15 @@ func (p *MediaPort) Close() { p.dtmfOutAudio = nil } p.dtmfIn.Store(nil) - _ = p.conn.Close() + _ = p.port.Close() } func (p *MediaPort) Port() int { - return p.conn.LocalAddr().Port + return p.port.LocalAddr().(*net.UDPAddr).Port } func (p *MediaPort) Received() <-chan struct{} { - return p.conn.Received() + return p.mediaReceived.Watch() } func (p *MediaPort) Timeout() <-chan struct{} { @@ -153,8 +204,8 @@ func (p *MediaPort) GetAudioWriter() media.PCM16Writer { } // NewOffer generates an SDP offer for the media. -func (p *MediaPort) NewOffer() *sdp.Offer { - return sdp.NewOffer(p.externalIP, p.Port()) +func (p *MediaPort) NewOffer(encrypted bool) (*sdp.Offer, error) { + return sdp.NewOffer(p.externalIP, p.Port(), encrypted) } // SetAnswer decodes and applies SDP answer for offer from NewOffer. SetConfig must be called with the decoded configuration. @@ -184,30 +235,104 @@ func (p *MediaPort) SetOffer(offerData []byte) (*sdp.Answer, *MediaConf, error) } func (p *MediaPort) SetConfig(c *MediaConf) error { + var crypto string + if c.Crypto != nil { + crypto = c.Crypto.Profile.String() + } p.log.Infow("using codecs", "audio-codec", c.Audio.Codec.Info().SDPName, "audio-rtp", c.Audio.Type, "dtmf-rtp", c.Audio.DTMFType, + "srtp", crypto, ) + p.port.SetDst(c.Remote) + var ( + sess rtp.Session + err error + ) + if c.Crypto != nil { + sess, err = srtp.NewSession(p.port, c.Crypto) + } else { + sess = rtp.NewSession(p.port) + } + if err != nil { + return err + } + p.mu.Lock() defer p.mu.Unlock() - if ip := c.Remote; ip.IsValid() { - p.conn.SetDestAddr(&net.UDPAddr{ - IP: ip.Addr().AsSlice(), - Port: int(ip.Port()), - }) - } + p.port.SetDst(c.Remote) p.conf = c + p.sess = sess - p.setupOutput() + if err = p.setupOutput(); err != nil { + return err + } p.setupInput() return nil } +func (p *MediaPort) rtpLoop(sess rtp.Session) { + // Need a loop to process all incoming packets. + first := true + for { + r, ssrc, err := sess.AcceptStream() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + p.log.Errorw("cannot accept RTP stream", err) + } + return + } + p.mediaReceived.Break() + if first { + first = false + p.log.Infow("accepting media", "ssrc", ssrc) + go p.rtpReadLoop(r) + } else { + p.log.Warnw("ignoring media", nil, "ssrc", ssrc) + } + } +} + +func (p *MediaPort) rtpReadLoop(r rtp.ReadStream) { + buf := make([]byte, 1500) + var h rtp.Header + for { + h = rtp.Header{} + n, err := r.ReadRTP(&h, buf) + if err == io.EOF { + return + } else if err != nil { + p.log.Errorw("read RTP failed", err) + return + } + + ptr := p.hnd.Load() + if ptr == nil { + continue + } + hnd := *ptr + if hnd == nil { + continue + } + err = hnd.HandleRTP(&h, buf[:n]) + if err != nil { + p.log.Errorw("handle RTP failed", err) + continue + } + } +} + // Must be called holding the lock -func (p *MediaPort) setupOutput() { +func (p *MediaPort) setupOutput() error { + go p.rtpLoop(p.sess) + w, err := p.sess.OpenWriteStream() + if err != nil { + return err + } + // TODO: this says "audio", but actually includes DTMF too - s := rtp.NewSeqWriter(newRTPStatsWriter(p.mon, "audio", p.conn)) + s := rtp.NewSeqWriter(newRTPStatsWriter(p.mon, "audio", w)) p.audioOutRTP = s.NewStream(p.conf.Audio.Type, p.conf.Audio.Codec.Info().RTPClockRate) // Encoding pipeline (LK -> SIP) @@ -230,6 +355,7 @@ func (p *MediaPort) setupOutput() { if w := p.audioOut.Swap(audioOut); w != nil { _ = w.Close() } + return nil } func (p *MediaPort) setupInput() { @@ -241,19 +367,20 @@ func (p *MediaPort) setupInput() { mux.SetDefault(newRTPStatsHandler(p.mon, "", nil)) mux.Register(p.conf.Audio.Type, newRTPStatsHandler(p.mon, p.conf.Audio.Codec.Info().SDPName, audioHandler)) if p.conf.Audio.DTMFType != 0 { - mux.Register(p.conf.Audio.DTMFType, newRTPStatsHandler(p.mon, dtmf.SDPName, rtp.HandlerFunc(func(pck *rtp.Packet) error { + mux.Register(p.conf.Audio.DTMFType, newRTPStatsHandler(p.mon, dtmf.SDPName, rtp.HandlerFunc(func(h *rtp.Header, payload []byte) error { ptr := p.dtmfIn.Load() if ptr == nil { return nil } fnc := *ptr - if ev, ok := dtmf.DecodeRTP(pck); ok && fnc != nil { + if ev, ok := dtmf.DecodeRTP(h, payload); ok && fnc != nil { fnc(ev) } return nil }))) } - p.conn.OnRTP(mux) + var hnd rtp.Handler = mux + p.hnd.Store(&hnd) } // SetDTMFAudio forces SIP to generate audio dTMF tones in addition to digital signals. diff --git a/pkg/sip/media_port_test.go b/pkg/sip/media_port_test.go index 94494bbb..21c2b085 100644 --- a/pkg/sip/media_port_test.go +++ b/pkg/sip/media_port_test.go @@ -37,17 +37,60 @@ import ( ) type testUDPConn struct { - addr *net.UDPAddr + addr netip.AddrPort closed chan struct{} buf chan []byte peer atomic.Pointer[testUDPConn] } -func (c *testUDPConn) LocalAddr() net.Addr { - return c.addr +func (c *testUDPConn) Read(b []byte) (int, error) { + n, _, err := c.ReadFromUDPAddrPort(b) + return n, err +} + +func (c *testUDPConn) Write(b []byte) (int, error) { + return c.WriteToUDPAddrPort(b, netip.AddrPort{}) +} + +func (c *testUDPConn) RemoteAddr() net.Addr { + p := c.peer.Load() + if p == nil { + return &net.UDPAddr{} + } + return p.LocalAddr() +} + +func (c *testUDPConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *testUDPConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *testUDPConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func (c *testUDPConn) ReadFromUDPAddrPort(buf []byte) (int, netip.AddrPort, error) { + peer := c.peer.Load() + if peer == nil { + return 0, netip.AddrPort{}, io.ErrClosedPipe + } + select { + case <-c.closed: + return 0, netip.AddrPort{}, io.ErrClosedPipe + case data := <-c.buf: + n := copy(buf, data) + var err error + if n < len(data) { + err = io.ErrShortBuffer + } + return n, peer.addr, err + } } -func (c *testUDPConn) WriteToUDP(buf []byte, addr *net.UDPAddr) (int, error) { +func (c *testUDPConn) WriteToUDPAddrPort(buf []byte, addr netip.AddrPort) (int, error) { peer := c.peer.Load() if peer == nil { return 0, io.ErrClosedPipe @@ -65,21 +108,10 @@ func (c *testUDPConn) WriteToUDP(buf []byte, addr *net.UDPAddr) (int, error) { } } -func (c *testUDPConn) ReadFromUDP(buf []byte) (int, *net.UDPAddr, error) { - peer := c.peer.Load() - if peer == nil { - return 0, nil, io.ErrClosedPipe - } - select { - case <-c.closed: - return 0, nil, io.ErrClosedPipe - case data := <-c.buf: - n := copy(buf, data) - var err error - if n < len(data) { - err = io.ErrShortBuffer - } - return n, peer.addr, err +func (c *testUDPConn) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: c.addr.Addr().AsSlice(), + Port: int(c.addr.Port()), } } @@ -90,20 +122,20 @@ func (c *testUDPConn) Close() error { return nil } -func newUDPConn(i int) *testUDPConn { +func newTestConn(i int) *testUDPConn { return &testUDPConn{ - addr: &net.UDPAddr{ - IP: net.IPv4(byte(i), byte(i), byte(i), byte(i)), - Port: 10000 * i, - }, + addr: netip.AddrPortFrom( + netip.AddrFrom4([4]byte{byte(i), byte(i), byte(i), byte(i)}), + uint16(10000*i), + ), buf: make(chan []byte, 10), closed: make(chan struct{}), } } func newUDPPipe() (c1, c2 *testUDPConn) { - c1 = newUDPConn(1) - c2 = newUDPConn(2) + c1 = newTestConn(1) + c2 = newTestConn(2) c1.peer.Store(c2) c2.peer.Store(c1) return @@ -151,11 +183,18 @@ func TestMediaPort(t *testing.T) { nativeRate *= 2 // error in RFC } - for _, rate := range []int{ - nativeRate, - 48000, + for _, tconf := range []struct { + Rate int + Encrypted bool + }{ + {nativeRate, false}, + {48000, true}, } { - t.Run(strconv.Itoa(rate), func(t *testing.T) { + suff := "" + if tconf.Encrypted { + suff = " srtp" + } + t.Run(fmt.Sprintf("%d%s", tconf.Rate, suff), func(t *testing.T) { c1, c2 := newUDPPipe() log := logger.GetLogger() @@ -163,18 +202,19 @@ func TestMediaPort(t *testing.T) { m1, err := NewMediaPortWith(log.WithName("one"), nil, c1, &MediaConfig{ IP: newIP("1.1.1.1"), Ports: rtcconfig.PortRange{Start: 10000}, - }, rate) + }, tconf.Rate) require.NoError(t, err) defer m1.Close() m2, err := NewMediaPortWith(log.WithName("two"), nil, c2, &MediaConfig{ IP: newIP("2.2.2.2"), Ports: rtcconfig.PortRange{Start: 20000}, - }, rate) + }, tconf.Rate) require.NoError(t, err) defer m2.Close() - offer := m1.NewOffer() + offer, err := m1.NewOffer(tconf.Encrypted) + require.NoError(t, err) offerData, err := offer.SDP.Marshal() require.NoError(t, err) @@ -200,15 +240,15 @@ func TestMediaPort(t *testing.T) { require.Equal(t, info.SDPName, m2.Config().Audio.Codec.Info().SDPName) var buf1 media.PCM16Sample - m1.WriteAudioTo(media.NewPCM16BufferWriter(&buf1, rate)) + m1.WriteAudioTo(media.NewPCM16BufferWriter(&buf1, tconf.Rate)) var buf2 media.PCM16Sample - m2.WriteAudioTo(media.NewPCM16BufferWriter(&buf2, rate)) + m2.WriteAudioTo(media.NewPCM16BufferWriter(&buf2, tconf.Rate)) w1 := m1.GetAudioWriter() w2 := m2.GetAudioWriter() - packetSize := uint32(rate / int(time.Second/rtp.DefFrameDur)) + packetSize := uint32(tconf.Rate / int(time.Second/rtp.DefFrameDur)) sample1 := make(media.PCM16Sample, packetSize) sample2 := make(media.PCM16Sample, packetSize) for i := range packetSize { @@ -217,7 +257,7 @@ func TestMediaPort(t *testing.T) { } writes := 1 - if rate == nativeRate { + if tconf.Rate == nativeRate { expChain := fmt.Sprintf("Switch(%d) -> %s(encode) -> RTP(%d)", nativeRate, name, nativeRate) require.Equal(t, expChain, w1.String()) require.Equal(t, expChain, w2.String()) diff --git a/pkg/sip/outbound.go b/pkg/sip/outbound.go index a9df96df..7d9da999 100644 --- a/pkg/sip/outbound.go +++ b/pkg/sip/outbound.go @@ -398,7 +398,10 @@ func (c *outboundCall) sipSignal(ctx context.Context) error { cancel() }() - sdpOffer := c.media.NewOffer() + sdpOffer, err := c.media.NewOffer(false) // FIXME: enable for TLS? + if err != nil { + return err + } sdpOfferData, err := sdpOffer.SDP.Marshal() if err != nil { return err diff --git a/pkg/sip/service_test.go b/pkg/sip/service_test.go index af8790fb..87088c0f 100644 --- a/pkg/sip/service_test.go +++ b/pkg/sip/service_test.go @@ -108,7 +108,8 @@ func testInvite(t *testing.T, h Handler, hidden bool, from, to string, test func sipClient, err := sipgo.NewClient(sipUserAgent) require.NoError(t, err) - offer := sdp.NewOffer(localIP, 0xB0B) + offer, err := sdp.NewOffer(localIP, 0xB0B, false) + require.NoError(t, err) offerData, err := offer.SDP.Marshal() require.NoError(t, err) diff --git a/pkg/siptest/client.go b/pkg/siptest/client.go index 4eacdec2..28866fe5 100644 --- a/pkg/siptest/client.go +++ b/pkg/siptest/client.go @@ -217,31 +217,31 @@ func (c *Client) Close() { func (c *Client) setupRTPReceiver() { var lastTs atomic.Uint32 - c.mux = rtp.NewMux(rtp.HandlerFunc(func(pck *rtp.Packet) error { - lastTs.Store(pck.Timestamp) + c.mux = rtp.NewMux(rtp.HandlerFunc(func(hdr *rtp.Header, payload []byte) error { + lastTs.Store(hdr.Timestamp) h := c.recordHandler.Load() if h != nil { - return (*h).HandleRTP(pck) + return (*h).HandleRTP(hdr, payload) } return nil })) - c.mux.Register(101, rtp.HandlerFunc(func(pck *rtp.Packet) error { + c.mux.Register(101, rtp.HandlerFunc(func(hdr *rtp.Header, payload []byte) error { ts := lastTs.Load() var diff int64 if ts > 0 { - diff = int64(pck.Timestamp) - int64(ts) + diff = int64(hdr.Timestamp) - int64(ts) } if diff > int64(c.audioCodec.Info().RTPClockRate) || diff < -int64(c.audioCodec.Info().RTPClockRate) { - c.log.Info("reveived out of sync DTMF message", "dtmfTs", pck.Timestamp, "lastTs", ts) + c.log.Info("reveived out of sync DTMF message", "dtmfTs", hdr.Timestamp, "lastTs", ts) return nil } if c.conf.OnDTMF == nil { return nil } - if ev, ok := dtmf.DecodeRTP(pck); ok { + if ev, ok := dtmf.DecodeRTP(hdr, payload); ok { c.conf.OnDTMF(ev) } return nil @@ -655,7 +655,7 @@ func (c *Client) WaitSignals(ctx context.Context, vals []int, w io.WriteCloser) pkts := make(chan *rtp.Packet, 1) done := make(chan struct{}) - h := rtp.Handler(rtp.HandlerFunc(func(pkt *rtp.Packet) error { + h := rtp.Handler(rtp.HandlerFunc(func(hdr *rtp.Header, payload []byte) error { // Make sure er do not send on a closed channel select { case <-done: @@ -668,7 +668,7 @@ func (c *Client) WaitSignals(ctx context.Context, vals []int, w io.WriteCloser) close(pkts) close(done) return ctx.Err() - case pkts <- pkt: + case pkts <- &rtp.Packet{Header: *hdr, Payload: slices.Clone(payload)}: } return nil @@ -687,7 +687,7 @@ func (c *Client) WaitSignals(ctx context.Context, vals []int, w io.WriteCloser) continue } decoded = decoded[:0] - if err := dec.HandleRTP(p); err != nil { + if err := dec.HandleRTP(&p.Header, p.Payload); err != nil { return err } if ws != nil {