Skip to content

Commit

Permalink
[MM-60189] Rate limit forwarding PLI requests (#154)
Browse files Browse the repository at this point in the history
* Send PLI request on receiving screen track

* Rate limit forwarding PLI requests

* Tests
  • Loading branch information
streamer45 authored Sep 16, 2024
1 parent 3e1e123 commit c9ed0e1
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 25 deletions.
36 changes: 31 additions & 5 deletions client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package client
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"time"

"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
)

Expand Down Expand Up @@ -52,7 +54,9 @@ func (c *Client) Unmute(track webrtc.TrackLocal) error {
rtcpBuf := make([]byte, receiveMTU)
for {
if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil {
c.log.Error("failed to read rtcp", slog.String("err", rtcpErr.Error()))
if !errors.Is(rtcpErr, io.EOF) {
c.log.Error("failed to read rtcp", slog.String("err", rtcpErr.Error()))
}
return
}
}
Expand Down Expand Up @@ -113,16 +117,38 @@ func (c *Client) StartScreenShare(tracks []webrtc.TrackLocal) (*webrtc.RTPTransc

sender := trx.Sender()

go func() {
rtcpHandler := func(rid string) {
defer c.log.Debug("exiting RTCP handler")
var n int
var err error
rtcpBuf := make([]byte, receiveMTU)
for {
if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil {
c.log.Error("failed to read rtcp", slog.String("err", rtcpErr.Error()))
if rid != "" {
n, _, err = sender.ReadSimulcast(rtcpBuf, rid)
} else {
n, _, err = sender.Read(rtcpBuf)
}
if err != nil {
if !errors.Is(err, io.EOF) {
c.log.Error("failed to read RTCP packet", slog.String("err", err.Error()))
}
return
}
if pkts, err := rtcp.Unmarshal(rtcpBuf[:n]); err != nil {
c.log.Error("failed to unmarshal RTCP packet", slog.String("err", err.Error()))
} else {
c.emit(RTCSenderRTCPPacketEvent, map[string]any{
"pkts": pkts,
"rid": rid,
"sender": sender,
})
}
}
}()
}

for _, track := range tracks {
go rtcpHandler(track.RID())
}

return trx, nil
}
Expand Down
140 changes: 140 additions & 0 deletions client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1024,3 +1025,142 @@ func TestAPIScreenShareAndVoice(t *testing.T) {

require.Greater(t, packets, 10)
}

func TestAPIScreenSharePLI(t *testing.T) {
th := setupTestHelper(t, "calls0")

nReceivers := 10

// Setup
userClients := make([]*Client, nReceivers)
userConnectChs := make([]chan struct{}, nReceivers)
userCloseChs := make([]chan struct{}, nReceivers)

for i := 0; i < len(userClients); i++ {
userConnectChs[i] = make(chan struct{})
userCloseChs[i] = make(chan struct{})

var err error
userClients[i], err = New(Config{
SiteURL: th.apiURL,
AuthToken: th.userAPIClient.AuthToken,
ChannelID: th.channels["calls0"].Id,
})
require.NoError(t, err)
require.NotNil(t, userClients[i])

client := userClients[i]
connectedCh := userConnectChs[i]
err = client.On(RTCConnectEvent, func(_ any) error {
close(connectedCh)
return nil
})
require.NoError(t, err)
}

adminConnectCh := make(chan struct{})
err := th.adminClient.On(RTCConnectEvent, func(_ any) error {
close(adminConnectCh)
return nil
})
require.NoError(t, err)

for i := 0; i < len(userClients); i++ {
go func(i int) {
err := userClients[i].Connect()
require.NoError(t, err)
}(i)
}

go func() {
err := th.adminClient.Connect()
require.NoError(t, err)
}()

for i := 0; i < len(userClients); i++ {
select {
case <-userConnectChs[i]:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user connect event")
}
}

select {
case <-adminConnectCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for admin connect event")
}

adminCloseCh := make(chan struct{})

// Test logic

// admin screen shares, users should receive the track
adminScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8)
_, err = th.adminClient.StartScreenShare([]webrtc.TrackLocal{adminScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(adminScreenTrack, adminCloseCh)

var pliCount int
err = th.adminClient.On(RTCSenderRTCPPacketEvent, func(ctx any) error {
m := ctx.(map[string]any)
for _, pkt := range m["pkts"].([]rtcp.Packet) {
if _, ok := pkt.(*rtcp.PictureLossIndication); ok {
pliCount++
}
}
return nil
})
require.NoError(t, err)

time.Sleep(2 * time.Second)

err = th.adminClient.StopScreenShare()
require.NoError(t, err)

// Teardown

for i := 0; i < len(userClients); i++ {
client := userClients[i]
closeCh := userCloseChs[i]
err = client.On(CloseEvent, func(_ any) error {
close(closeCh)
return nil
})
require.NoError(t, err)
}

err = th.adminClient.On(CloseEvent, func(_ any) error {
close(adminCloseCh)
return nil
})
require.NoError(t, err)

for i := 0; i < len(userClients); i++ {
go func(i int) {
err := userClients[i].Close()
require.NoError(t, err)
}(i)
}

go func() {
err := th.adminClient.Close()
require.NoError(t, err)
}()

for i := 0; i < len(userClients); i++ {
select {
case <-userCloseChs[i]:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}
}

select {
case <-adminCloseCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}

require.Equal(t, 1, pliCount)
}
9 changes: 5 additions & 4 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ type EventHandler func(ctx any) error
type EventType string

const (
RTCConnectEvent EventType = "RTCConnect"
RTCDisconnectEvent EventType = "RTCDisconnect"
RTCTrackEvent EventType = "RTCTrack"
RTCConnectEvent EventType = "RTCConnect"
RTCDisconnectEvent EventType = "RTCDisconnect"
RTCTrackEvent EventType = "RTCTrack"
RTCSenderRTCPPacketEvent EventType = "RTCSenderRTCPPacket"

CloseEvent EventType = "Close"
ErrorEvent EventType = "Error"
Expand All @@ -47,7 +48,7 @@ const (

func (e EventType) IsValid() bool {
switch e {
case RTCConnectEvent, RTCDisconnectEvent, RTCTrackEvent,
case RTCConnectEvent, RTCDisconnectEvent, RTCTrackEvent, RTCSenderRTCPPacketEvent,
CloseEvent,
ErrorEvent,
WSConnectEvent, WSDisconnectEvent,
Expand Down
12 changes: 9 additions & 3 deletions client/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"log"
"log/slog"
"os"
"testing"
"time"
Expand Down Expand Up @@ -45,7 +46,7 @@ const (
userPass = "U$er-sample1"
teamName = "calls"
nChannels = 2
waitTimeout = 5 * time.Second
waitTimeout = 10 * time.Second
)

func (th *TestHelper) newScreenTrack(mimeType string) *webrtc.TrackLocalStaticRTP {
Expand Down Expand Up @@ -326,6 +327,11 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper {
tb.Helper()
var err error

logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}))

th := &TestHelper{
tb: tb,
channels: make(map[string]*model.Channel),
Expand Down Expand Up @@ -354,15 +360,15 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper {
SiteURL: th.apiURL,
AuthToken: th.adminAPIClient.AuthToken,
ChannelID: channelID,
})
}, WithLogger(logger))
require.NoError(tb, err)
require.NotNil(tb, th.adminClient)

th.userClient, err = New(Config{
SiteURL: th.apiURL,
AuthToken: th.userAPIClient.AuthToken,
ChannelID: channelID,
})
}, WithLogger(logger))
require.NoError(tb, err)
require.NotNil(tb, th.userClient)

Expand Down
11 changes: 10 additions & 1 deletion client/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync/atomic"

"github.com/pion/interceptor"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
)

Expand Down Expand Up @@ -242,14 +243,22 @@ func (c *Client) initRTCSession() error {
return
}

if trackType == TrackTypeScreen {
c.log.Debug("sending PLI request for received screen track", slog.String("trackID", track.ID()), slog.Any("SSRC", track.SSRC()))
if err := pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); err != nil {
c.log.Error("failed to write RTCP packet", slog.String("err", err.Error()))
}
}

c.mut.Lock()
c.receivers[sessionID] = append(c.receivers[sessionID], receiver)
c.mut.Unlock()

// RTCP handler
go func(rid string) {
var err error
rtcpBuf := make([]byte, receiveMTU)
for {
rtcpBuf := make([]byte, receiveMTU)
if rid != "" {
_, _, err = receiver.ReadSimulcast(rtcpBuf, rid)
} else {
Expand Down
6 changes: 3 additions & 3 deletions client/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ func (c *Client) handleWSEventHello(ev *model.WebSocketEvent) (isReconnect bool,
}

if connID != c.currentConnID {
c.log.Debug("new connection id from server")
c.log.Debug("new connection id from server", slog.String("connID", connID))
}

if c.originalConnID == "" {
c.log.Debug("initial ws connection")
c.log.Debug("initial ws connection", slog.String("originalConnID", connID))
c.originalConnID = connID
} else {
c.log.Debug("ws reconnected successfully")
c.log.Debug("ws reconnected successfully", slog.String("originalConnID", c.originalConnID))
c.wsLastDisconnect = time.Time{}
c.wsReconnectInterval = 0
isReconnect = true
Expand Down
2 changes: 2 additions & 0 deletions service/rtc/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"

"github.com/pion/webrtc/v3"
"golang.org/x/time/rate"

"github.com/mattermost/mattermost/server/public/shared/mlog"
)
Expand All @@ -16,6 +17,7 @@ type call struct {
id string
sessions map[string]*session
screenSession *session
pliLimiters map[webrtc.SSRC]*rate.Limiter
metrics Metrics

mut sync.RWMutex
Expand Down
Loading

0 comments on commit c9ed0e1

Please sign in to comment.