From c9ed0e193fb57fe3b91ce9debdc66bdba1006341 Mon Sep 17 00:00:00 2001 From: Claudio Costa Date: Mon, 16 Sep 2024 16:49:20 -0600 Subject: [PATCH] [MM-60189] Rate limit forwarding PLI requests (#154) * Send PLI request on receiving screen track * Rate limit forwarding PLI requests * Tests --- client/api.go | 36 +++++++++-- client/api_test.go | 140 +++++++++++++++++++++++++++++++++++++++++ client/client.go | 9 +-- client/helper_test.go | 12 +++- client/rtc.go | 11 +++- client/websocket.go | 6 +- service/rtc/call.go | 2 + service/rtc/session.go | 41 +++++++++--- 8 files changed, 232 insertions(+), 25 deletions(-) diff --git a/client/api.go b/client/api.go index 61d7f8f..a8647e6 100644 --- a/client/api.go +++ b/client/api.go @@ -6,12 +6,14 @@ package client import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" "time" + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) @@ -52,7 +54,9 @@ func (c *Client) Unmute(track webrtc.TrackLocal) error { rtcpBuf := make([]byte, receiveMTU) for { if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil { - c.log.Error("failed to read rtcp", slog.String("err", rtcpErr.Error())) + if !errors.Is(rtcpErr, io.EOF) { + c.log.Error("failed to read rtcp", slog.String("err", rtcpErr.Error())) + } return } } @@ -113,16 +117,38 @@ func (c *Client) StartScreenShare(tracks []webrtc.TrackLocal) (*webrtc.RTPTransc sender := trx.Sender() - go func() { + rtcpHandler := func(rid string) { defer c.log.Debug("exiting RTCP handler") + var n int + var err error rtcpBuf := make([]byte, receiveMTU) for { - if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil { - c.log.Error("failed to read rtcp", slog.String("err", rtcpErr.Error())) + if rid != "" { + n, _, err = sender.ReadSimulcast(rtcpBuf, rid) + } else { + n, _, err = sender.Read(rtcpBuf) + } + if err != nil { + if !errors.Is(err, io.EOF) { + c.log.Error("failed to read RTCP packet", slog.String("err", err.Error())) + } return } + if pkts, err := rtcp.Unmarshal(rtcpBuf[:n]); err != nil { + c.log.Error("failed to unmarshal RTCP packet", slog.String("err", err.Error())) + } else { + c.emit(RTCSenderRTCPPacketEvent, map[string]any{ + "pkts": pkts, + "rid": rid, + "sender": sender, + }) + } } - }() + } + + for _, track := range tracks { + go rtcpHandler(track.RID()) + } return trx, nil } diff --git a/client/api_test.go b/client/api_test.go index 3fa6564..7bdaba5 100644 --- a/client/api_test.go +++ b/client/api_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/require" @@ -1024,3 +1025,142 @@ func TestAPIScreenShareAndVoice(t *testing.T) { require.Greater(t, packets, 10) } + +func TestAPIScreenSharePLI(t *testing.T) { + th := setupTestHelper(t, "calls0") + + nReceivers := 10 + + // Setup + userClients := make([]*Client, nReceivers) + userConnectChs := make([]chan struct{}, nReceivers) + userCloseChs := make([]chan struct{}, nReceivers) + + for i := 0; i < len(userClients); i++ { + userConnectChs[i] = make(chan struct{}) + userCloseChs[i] = make(chan struct{}) + + var err error + userClients[i], err = New(Config{ + SiteURL: th.apiURL, + AuthToken: th.userAPIClient.AuthToken, + ChannelID: th.channels["calls0"].Id, + }) + require.NoError(t, err) + require.NotNil(t, userClients[i]) + + client := userClients[i] + connectedCh := userConnectChs[i] + err = client.On(RTCConnectEvent, func(_ any) error { + close(connectedCh) + return nil + }) + require.NoError(t, err) + } + + adminConnectCh := make(chan struct{}) + err := th.adminClient.On(RTCConnectEvent, func(_ any) error { + close(adminConnectCh) + return nil + }) + require.NoError(t, err) + + for i := 0; i < len(userClients); i++ { + go func(i int) { + err := userClients[i].Connect() + require.NoError(t, err) + }(i) + } + + go func() { + err := th.adminClient.Connect() + require.NoError(t, err) + }() + + for i := 0; i < len(userClients); i++ { + select { + case <-userConnectChs[i]: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for user connect event") + } + } + + select { + case <-adminConnectCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for admin connect event") + } + + adminCloseCh := make(chan struct{}) + + // Test logic + + // admin screen shares, users should receive the track + adminScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8) + _, err = th.adminClient.StartScreenShare([]webrtc.TrackLocal{adminScreenTrack}) + require.NoError(t, err) + go th.screenTrackWriter(adminScreenTrack, adminCloseCh) + + var pliCount int + err = th.adminClient.On(RTCSenderRTCPPacketEvent, func(ctx any) error { + m := ctx.(map[string]any) + for _, pkt := range m["pkts"].([]rtcp.Packet) { + if _, ok := pkt.(*rtcp.PictureLossIndication); ok { + pliCount++ + } + } + return nil + }) + require.NoError(t, err) + + time.Sleep(2 * time.Second) + + err = th.adminClient.StopScreenShare() + require.NoError(t, err) + + // Teardown + + for i := 0; i < len(userClients); i++ { + client := userClients[i] + closeCh := userCloseChs[i] + err = client.On(CloseEvent, func(_ any) error { + close(closeCh) + return nil + }) + require.NoError(t, err) + } + + err = th.adminClient.On(CloseEvent, func(_ any) error { + close(adminCloseCh) + return nil + }) + require.NoError(t, err) + + for i := 0; i < len(userClients); i++ { + go func(i int) { + err := userClients[i].Close() + require.NoError(t, err) + }(i) + } + + go func() { + err := th.adminClient.Close() + require.NoError(t, err) + }() + + for i := 0; i < len(userClients); i++ { + select { + case <-userCloseChs[i]: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for close event") + } + } + + select { + case <-adminCloseCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for close event") + } + + require.Equal(t, 1, pliCount) +} diff --git a/client/client.go b/client/client.go index 7963c88..21627f3 100644 --- a/client/client.go +++ b/client/client.go @@ -23,9 +23,10 @@ type EventHandler func(ctx any) error type EventType string const ( - RTCConnectEvent EventType = "RTCConnect" - RTCDisconnectEvent EventType = "RTCDisconnect" - RTCTrackEvent EventType = "RTCTrack" + RTCConnectEvent EventType = "RTCConnect" + RTCDisconnectEvent EventType = "RTCDisconnect" + RTCTrackEvent EventType = "RTCTrack" + RTCSenderRTCPPacketEvent EventType = "RTCSenderRTCPPacket" CloseEvent EventType = "Close" ErrorEvent EventType = "Error" @@ -47,7 +48,7 @@ const ( func (e EventType) IsValid() bool { switch e { - case RTCConnectEvent, RTCDisconnectEvent, RTCTrackEvent, + case RTCConnectEvent, RTCDisconnectEvent, RTCTrackEvent, RTCSenderRTCPPacketEvent, CloseEvent, ErrorEvent, WSConnectEvent, WSDisconnectEvent, diff --git a/client/helper_test.go b/client/helper_test.go index 7e87eca..dac4367 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "log/slog" "os" "testing" "time" @@ -45,7 +46,7 @@ const ( userPass = "U$er-sample1" teamName = "calls" nChannels = 2 - waitTimeout = 5 * time.Second + waitTimeout = 10 * time.Second ) func (th *TestHelper) newScreenTrack(mimeType string) *webrtc.TrackLocalStaticRTP { @@ -326,6 +327,11 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper { tb.Helper() var err error + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + })) + th := &TestHelper{ tb: tb, channels: make(map[string]*model.Channel), @@ -354,7 +360,7 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper { SiteURL: th.apiURL, AuthToken: th.adminAPIClient.AuthToken, ChannelID: channelID, - }) + }, WithLogger(logger)) require.NoError(tb, err) require.NotNil(tb, th.adminClient) @@ -362,7 +368,7 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper { SiteURL: th.apiURL, AuthToken: th.userAPIClient.AuthToken, ChannelID: channelID, - }) + }, WithLogger(logger)) require.NoError(tb, err) require.NotNil(tb, th.userClient) diff --git a/client/rtc.go b/client/rtc.go index cc59d7c..4514512 100644 --- a/client/rtc.go +++ b/client/rtc.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "github.com/pion/interceptor" + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) @@ -242,14 +243,22 @@ func (c *Client) initRTCSession() error { return } + if trackType == TrackTypeScreen { + c.log.Debug("sending PLI request for received screen track", slog.String("trackID", track.ID()), slog.Any("SSRC", track.SSRC())) + if err := pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}); err != nil { + c.log.Error("failed to write RTCP packet", slog.String("err", err.Error())) + } + } + c.mut.Lock() c.receivers[sessionID] = append(c.receivers[sessionID], receiver) c.mut.Unlock() // RTCP handler go func(rid string) { + var err error + rtcpBuf := make([]byte, receiveMTU) for { - rtcpBuf := make([]byte, receiveMTU) if rid != "" { _, _, err = receiver.ReadSimulcast(rtcpBuf, rid) } else { diff --git a/client/websocket.go b/client/websocket.go index 5953d4c..e5c60f9 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -111,14 +111,14 @@ func (c *Client) handleWSEventHello(ev *model.WebSocketEvent) (isReconnect bool, } if connID != c.currentConnID { - c.log.Debug("new connection id from server") + c.log.Debug("new connection id from server", slog.String("connID", connID)) } if c.originalConnID == "" { - c.log.Debug("initial ws connection") + c.log.Debug("initial ws connection", slog.String("originalConnID", connID)) c.originalConnID = connID } else { - c.log.Debug("ws reconnected successfully") + c.log.Debug("ws reconnected successfully", slog.String("originalConnID", c.originalConnID)) c.wsLastDisconnect = time.Time{} c.wsReconnectInterval = 0 isReconnect = true diff --git a/service/rtc/call.go b/service/rtc/call.go index 36e7b57..511c246 100644 --- a/service/rtc/call.go +++ b/service/rtc/call.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/pion/webrtc/v3" + "golang.org/x/time/rate" "github.com/mattermost/mattermost/server/public/shared/mlog" ) @@ -16,6 +17,7 @@ type call struct { id string sessions map[string]*session screenSession *session + pliLimiters map[webrtc.SSRC]*rate.Limiter metrics Metrics mut sync.RWMutex diff --git a/service/rtc/session.go b/service/rtc/session.go index 22c458d..badf67f 100644 --- a/service/rtc/session.go +++ b/service/rtc/session.go @@ -11,6 +11,8 @@ import ( "sync" "time" + "golang.org/x/time/rate" + "github.com/mattermost/rtcd/service/rtc/vad" "github.com/pion/interceptor/pkg/cc" @@ -89,9 +91,10 @@ func (s *Server) addSession(cfg SessionConfig, peerConn *webrtc.PeerConnection, if c == nil { // call is missing, creating one c = &call{ - id: cfg.CallID, - sessions: map[string]*session{}, - metrics: s.metrics, + id: cfg.CallID, + sessions: map[string]*session{}, + pliLimiters: map[webrtc.SSRC]*rate.Limiter{}, + metrics: s.metrics, } g.calls[c.id] = c } @@ -284,7 +287,14 @@ func (s *session) handleSenderRTCP(sender *webrtc.RTPSender) { return } for _, pkt := range pkts { - if _, ok := pkt.(*rtcp.PictureLossIndication); ok { + if p, ok := pkt.(*rtcp.PictureLossIndication); ok { + // When a PLI is received the request is forwarded + // to the peer generating the track (e.g. presenter). + + for _, dstSSRC := range p.DestinationSSRC() { + s.log.Debug("received PLI request for track", mlog.String("sessionID", s.cfg.SessionID), mlog.Uint("SSRC", dstSSRC)) + } + screenSession := s.call.getScreenSession() if screenSession == nil { s.log.Error("screenSession should not be nil", mlog.String("sessionID", s.cfg.SessionID)) @@ -308,11 +318,24 @@ func (s *session) handleSenderRTCP(sender *webrtc.RTPSender) { return } - // When a PLI is received the request is forwarded - // to the peer generating the track (e.g. presenter). - if err := screenSession.rtcConn.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(screenTrack.SSRC())}}); err != nil { - s.log.Error("failed to write RTCP packet", mlog.Err(err), mlog.String("sessionID", s.cfg.SessionID)) - return + s.call.mut.Lock() + // We allow at most one PLI request per second for a given SSRC to avoid overloading the sender. + // If a receiving client were to miss it due to rate limiting (e.g. joining right in the second of backoff), + // it will request it again and eventually get it. + limiter, ok := s.call.pliLimiters[screenTrack.SSRC()] + if !ok { + s.log.Debug("creating new PLI limiter for track", mlog.Uint("SSRC", screenTrack.SSRC())) + limiter = rate.NewLimiter(1, 1) + s.call.pliLimiters[screenTrack.SSRC()] = limiter + } + s.call.mut.Unlock() + + if limiter.Allow() { + s.log.Debug("forwarding PLI request for track", mlog.String("sessionID", s.cfg.SessionID), mlog.Uint("SSRC", screenTrack.SSRC())) + if err := screenSession.rtcConn.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(screenTrack.SSRC())}}); err != nil { + s.log.Error("failed to write RTCP packet", mlog.Err(err), mlog.String("sessionID", s.cfg.SessionID)) + return + } } } }