Skip to content

Commit

Permalink
Initial support for SRTP.
Browse files Browse the repository at this point in the history
  • Loading branch information
dennwc committed Nov 25, 2024
1 parent ba2922c commit df0348f
Show file tree
Hide file tree
Showing 16 changed files with 873 additions and 169 deletions.
6 changes: 3 additions & 3 deletions pkg/media/dtmf/dtmf.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ func Decode(data []byte) (Event, error) {
}, nil
}

func DecodeRTP(p *rtp.Packet) (Event, bool) {
if !p.Marker {
func DecodeRTP(h *rtp.Header, payload []byte) (Event, bool) {
if !h.Marker {
return Event{}, false
}
ev, err := Decode(p.Payload)
ev, err := Decode(payload)
if err != nil {
return Event{}, false
}
Expand Down
22 changes: 12 additions & 10 deletions pkg/media/rtp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package rtp

import (
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -131,9 +132,11 @@ func (c *Conn) Listen(portMin, portMax int, listenAddr string) error {
if listenAddr == "" {
listenAddr = "0.0.0.0"
}

var err error
c.conn, err = ListenUDPPortRange(portMin, portMax, net.ParseIP(listenAddr))
ip, err := netip.ParseAddr(listenAddr)
if err != nil {
return err
}
c.conn, err = ListenUDPPortRange(portMin, portMax, ip)
if err != nil {
return err
}
Expand Down Expand Up @@ -167,24 +170,23 @@ func (c *Conn) readLoop() {
close(c.received)
}
if h := c.onRTP.Load(); h != nil {
_ = (*h).HandleRTP(&p)
_ = (*h).HandleRTP(&p.Header, p.Payload)
}
}
}

func (c *Conn) WriteRTP(p *rtp.Packet) error {
func (c *Conn) WriteRTP(h *rtp.Header, payload []byte) (int, error) {
addr := c.dest.Load()
if addr == nil {
return nil
return 0, nil
}
data, err := p.Marshal()
data, err := (&rtp.Packet{Header: *h, Payload: payload}).Marshal()
if err != nil {
return err
return 0, err
}
c.wmu.Lock()
defer c.wmu.Unlock()
_, err = c.conn.WriteToUDP(data, addr)
return err
return c.conn.WriteToUDP(data, addr)
}

func (c *Conn) ReadRTP() (*rtp.Packet, *net.UDPAddr, error) {
Expand Down
11 changes: 6 additions & 5 deletions pkg/media/rtp/jitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ package rtp
import (
"time"

"github.com/livekit/server-sdk-go/v2/pkg/jitter"
"github.com/pion/rtp"

"github.com/livekit/server-sdk-go/v2/pkg/jitter"
)

const (
Expand All @@ -41,11 +42,11 @@ type jitterHandler struct {
buf *jitter.Buffer
}

func (h *jitterHandler) HandleRTP(p *rtp.Packet) error {
h.buf.Push(p)
func (r *jitterHandler) HandleRTP(h *rtp.Header, payload []byte) error {
r.buf.Push(&rtp.Packet{Header: *h, Payload: payload})
var last error
for _, p := range h.buf.Pop(false) {
if err := h.h.HandleRTP(p); err != nil {
for _, p := range r.buf.Pop(false) {
if err := r.h.HandleRTP(&p.Header, p.Payload); err != nil {
last = err
}
}
Expand Down
7 changes: 4 additions & 3 deletions pkg/media/rtp/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ import (
"errors"
"math/rand"
"net"
"net/netip"
)

var ListenErr = errors.New("failed to listen on udp port")

func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) {
func ListenUDPPortRange(portMin, portMax int, ip netip.Addr) (*net.UDPConn, error) {
if portMin == 0 && portMax == 0 {
return net.ListenUDP("udp", &net.UDPAddr{
IP: IP,
IP: ip.AsSlice(),
Port: 0,
})
}
Expand All @@ -48,7 +49,7 @@ func ListenUDPPortRange(portMin, portMax int, IP net.IP) (*net.UDPConn, error) {
portCurrent := portStart

for {
c, e := net.ListenUDP("udp", &net.UDPAddr{IP: IP, Port: portCurrent})
c, e := net.ListenUDP("udp", &net.UDPAddr{IP: ip.AsSlice(), Port: portCurrent})
if e == nil {
return c, nil
}
Expand Down
18 changes: 9 additions & 9 deletions pkg/media/rtp/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@ type Mux struct {

// HandleRTP selects a Handler based on payload type.
// Types can be registered with Register. If no handler is set, a default one will be used.
func (m *Mux) HandleRTP(p *rtp.Packet) error {
func (m *Mux) HandleRTP(h *rtp.Header, payload []byte) error {
if m == nil {
return nil
}
var h Handler
var r Handler
m.mu.RLock()
if p.PayloadType < byte(len(m.static)) {
h = m.static[p.PayloadType]
if h.PayloadType < byte(len(m.static)) {
r = m.static[h.PayloadType]
} else {
h = m.dynamic[p.PayloadType]
r = m.dynamic[h.PayloadType]
}
if h == nil {
h = m.def
if r == nil {
r = m.def
}
m.mu.RUnlock()
if h == nil {
if r == nil {
return nil
}
return h.HandleRTP(p)
return r.HandleRTP(h, payload)
}

// SetDefault sets a default RTP handler.
Expand Down
49 changes: 25 additions & 24 deletions pkg/media/rtp/rtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package rtp
import (
"fmt"
"math/rand/v2"
"slices"
"sync"

"github.com/pion/interceptor"
Expand All @@ -31,21 +32,21 @@ type BytesFrame interface {
}

type Writer interface {
WriteRTP(p *rtp.Packet) error
WriteRTP(h *rtp.Header, payload []byte) (int, error)
}

type Reader interface {
ReadRTP() (*rtp.Packet, interceptor.Attributes, error)
}

type Handler interface {
HandleRTP(p *rtp.Packet) error
HandleRTP(h *rtp.Header, payload []byte) error
}

type HandlerFunc func(p *rtp.Packet) error
type HandlerFunc func(h *rtp.Header, payload []byte) error

func (fnc HandlerFunc) HandleRTP(p *rtp.Packet) error {
return fnc(p)
func (fnc HandlerFunc) HandleRTP(h *rtp.Header, payload []byte) error {
return fnc(h, payload)
}

func HandleLoop(r Reader, h Handler) error {
Expand All @@ -54,7 +55,7 @@ func HandleLoop(r Reader, h Handler) error {
if err != nil {
return err
}
err = h.HandleRTP(p)
err = h.HandleRTP(&p.Header, p.Payload)
if err != nil {
return err
}
Expand All @@ -64,26 +65,27 @@ func HandleLoop(r Reader, h Handler) error {
// Buffer is a Writer that clones and appends RTP packets into a slice.
type Buffer []*Packet

func (b *Buffer) WriteRTP(p *Packet) error {
p2 := p.Clone()
*b = append(*b, p2)
func (b *Buffer) WriteRTP(h *rtp.Header, payload []byte) error {
*b = append(*b, &rtp.Packet{
Header: *h,
Payload: slices.Clone(payload),
})
return nil
}

// NewSeqWriter creates an RTP writer that automatically increments the sequence number.
func NewSeqWriter(w Writer) *SeqWriter {
s := &SeqWriter{w: w}
s.p = rtp.Packet{
Header: rtp.Header{
Version: 2,
SSRC: rand.Uint32(),
SequenceNumber: 0,
},
s.h = rtp.Header{
Version: 2,
SSRC: rand.Uint32(),
SequenceNumber: 0,
}
return s
}

type Packet = rtp.Packet
type Header = rtp.Header

type Event struct {
Type byte
Expand All @@ -95,20 +97,19 @@ type Event struct {
type SeqWriter struct {
mu sync.Mutex
w Writer
p Packet
h Header
}

func (s *SeqWriter) WriteEvent(ev *Event) error {
s.mu.Lock()
defer s.mu.Unlock()
s.p.PayloadType = ev.Type
s.p.Payload = ev.Payload
s.p.Marker = ev.Marker
s.p.Timestamp = ev.Timestamp
if err := s.w.WriteRTP(&s.p); err != nil {
s.h.PayloadType = ev.Type
s.h.Marker = ev.Marker
s.h.Timestamp = ev.Timestamp
if _, err := s.w.WriteRTP(&s.h, ev.Payload); err != nil {
return err
}
s.p.Header.SequenceNumber++
s.h.SequenceNumber++
return nil
}

Expand Down Expand Up @@ -211,6 +212,6 @@ func (s *MediaStreamIn[T]) String() string {
return fmt.Sprintf("RTP(%d) -> %s", s.Writer.SampleRate(), s.Writer)
}

func (s *MediaStreamIn[T]) HandleRTP(p *rtp.Packet) error {
return s.Writer.WriteSample(T(p.Payload))
func (s *MediaStreamIn[T]) HandleRTP(_ *rtp.Header, payload []byte) error {
return s.Writer.WriteSample(T(payload))
}
Loading

0 comments on commit df0348f

Please sign in to comment.