From 33bef20cc0e0939f4b361eabfd72ec7820ed7978 Mon Sep 17 00:00:00 2001 From: Claudio Costa Date: Mon, 31 Oct 2022 17:26:42 +0100 Subject: [PATCH] Simplify handling of voice tracks (#77) --- service/rtc/call.go | 16 +++--- service/rtc/server.go | 22 ++++++-- service/rtc/session.go | 38 ++++--------- service/rtc/sfu.go | 124 ++++++++++++++++------------------------- service/rtc/utils.go | 6 ++ 5 files changed, 90 insertions(+), 116 deletions(-) diff --git a/service/rtc/call.go b/service/rtc/call.go index 6eab0da..e095b5d 100644 --- a/service/rtc/call.go +++ b/service/rtc/call.go @@ -31,15 +31,13 @@ func (c *call) addSession(cfg SessionConfig, rtcConn *webrtc.PeerConnection, clo } s := &session{ - cfg: cfg, - rtcConn: rtcConn, - iceInCh: make(chan []byte, signalChSize*2), - sdpInCh: make(chan []byte, signalChSize), - closeCh: make(chan struct{}), - closeCb: closeCb, - tracksCh: make(chan *webrtc.TrackLocalStaticRTP, tracksChSize), - trackEnableCh: make(chan bool, tracksChSize), - rtpSendersMap: make(map[*webrtc.TrackLocalStaticRTP]*webrtc.RTPSender), + cfg: cfg, + rtcConn: rtcConn, + iceInCh: make(chan []byte, signalChSize*2), + sdpInCh: make(chan []byte, signalChSize), + closeCh: make(chan struct{}), + closeCb: closeCb, + tracksCh: make(chan *webrtc.TrackLocalStaticRTP, tracksChSize), } c.sessions[cfg.SessionID] = s diff --git a/service/rtc/server.go b/service/rtc/server.go index 28b0d7e..8cb5862 100644 --- a/service/rtc/server.go +++ b/service/rtc/server.go @@ -279,11 +279,25 @@ func (s *Server) msgReader() { } call.mut.Unlock() case MuteMessage, UnmuteMessage: - select { - case session.trackEnableCh <- (msg.Type == MuteMessage): - default: - s.log.Error("failed to send track enable message: channel is full") + session.mut.RLock() + track := session.outVoiceTrack + session.mut.RUnlock() + if track == nil { + continue + } + + var enabled bool + if msg.Type == UnmuteMessage { + enabled = true } + + s.log.Debug("setting voice track state", + mlog.Bool("enabled", enabled), + mlog.String("sessionID", session.cfg.SessionID)) + + session.mut.Lock() + session.outVoiceTrackEnabled = enabled + session.mut.Unlock() default: s.log.Error("received unexpected message type") } diff --git a/service/rtc/session.go b/service/rtc/session.go index df21195..71a6ab5 100644 --- a/service/rtc/session.go +++ b/service/rtc/session.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "github.com/mattermost/rtcd/service/random" - "github.com/pion/rtcp" "github.com/pion/webrtc/v3" @@ -35,8 +33,6 @@ type session struct { remoteScreenTrack *webrtc.TrackRemote rtcConn *webrtc.PeerConnection tracksCh chan *webrtc.TrackLocalStaticRTP - trackEnableCh chan bool - rtpSendersMap map[*webrtc.TrackLocalStaticRTP]*webrtc.RTPSender iceInCh chan []byte sdpInCh chan []byte @@ -113,7 +109,7 @@ func (s *session) handleICE(log mlog.LoggerIFace, m Metrics) { var candidate webrtc.ICECandidateInit if err := json.Unmarshal(data, &candidate); err != nil { - log.Error("failed to encode ice candidate", mlog.Err(err)) + log.Error("failed to encode ice candidate", mlog.Err(err), mlog.String("sessionID", s.cfg.SessionID)) continue } @@ -121,10 +117,10 @@ func (s *session) handleICE(log mlog.LoggerIFace, m Metrics) { continue } - log.Debug("setting ICE candidate for remote") + log.Debug("setting ICE candidate for remote", mlog.String("sessionID", s.cfg.SessionID)) if err := s.rtcConn.AddICECandidate(candidate); err != nil { - log.Error("failed to add ice candidate", mlog.Err(err)) + log.Error("failed to add ice candidate", mlog.Err(err), mlog.String("sessionID", s.cfg.SessionID)) m.IncRTCErrors(s.cfg.GroupID, "ice") continue } @@ -141,25 +137,26 @@ func (s *session) handlePLI(log mlog.LoggerIFace, call *call, sender *webrtc.RTP for { pkts, _, err := sender.ReadRTCP() if err != nil { - log.Error("failed to read RTCP packet", mlog.Err(err)) + log.Error("failed to read RTCP packet", + mlog.Err(err), mlog.String("sessionID", s.cfg.SessionID)) return } for _, pkt := range pkts { if _, ok := pkt.(*rtcp.PictureLossIndication); ok { screenSession := call.getScreenSession() if screenSession == nil { - log.Error("screenSession should not be nil") + log.Error("screenSession should not be nil", mlog.String("sessionID", s.cfg.SessionID)) return } screenTrack := screenSession.getRemoteScreenTrack() if screenTrack == nil { - log.Error("screenTrack should not be nil") + log.Error("screenTrack should not be nil", mlog.String("sessionID", s.cfg.SessionID)) return } if err := screenSession.rtcConn.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(screenTrack.SSRC())}}); err != nil { - log.Error("failed to write RTCP packet", mlog.Err(err)) + log.Error("failed to write RTCP packet", mlog.Err(err), mlog.String("sessionID", s.cfg.SessionID)) return } } @@ -168,20 +165,11 @@ func (s *session) handlePLI(log mlog.LoggerIFace, call *call, sender *webrtc.RTP } // addTrack adds the given track to the peer and starts negotiation. -func (s *session) addTrack(log mlog.LoggerIFace, c *call, sdpOutCh chan<- Message, track *webrtc.TrackLocalStaticRTP, enabled bool) error { - t := track - if !enabled { - dummyTrack, err := webrtc.NewTrackLocalStaticRTP(rtpAudioCodec, "voice", random.NewID()) - if err != nil { - return fmt.Errorf("failed to create new static track: %w", err) - } - t = dummyTrack - } - - sender, err := s.rtcConn.AddTrack(t) +func (s *session) addTrack(log mlog.LoggerIFace, c *call, sdpOutCh chan<- Message, track *webrtc.TrackLocalStaticRTP) error { + sender, err := s.rtcConn.AddTrack(track) if err != nil { return fmt.Errorf("failed to add track: %w", err) - } else if t.Kind() == webrtc.RTPCodecTypeVideo { + } else if track.Kind() == webrtc.RTPCodecTypeVideo { go s.handlePLI(log, c, sender) } @@ -223,10 +211,6 @@ func (s *session) addTrack(log mlog.LoggerIFace, c *call, sdpOutCh chan<- Messag return fmt.Errorf("failed to set remote description: %w", err) } - s.mut.Lock() - s.rtpSendersMap[track] = sender - s.mut.Unlock() - return nil } diff --git a/service/rtc/sfu.go b/service/rtc/sfu.go index 15ff771..854d518 100644 --- a/service/rtc/sfu.go +++ b/service/rtc/sfu.go @@ -164,19 +164,19 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { } msg, err := newICEMessage(us, candidate) if err != nil { - s.log.Error("failed to create ICE message", mlog.Err(err)) + s.log.Error("failed to create ICE message", mlog.Err(err), mlog.String("sessionID", cfg.SessionID)) return } select { case s.receiveCh <- msg: default: - s.log.Error("failed to send ICE message: channel is full") + s.log.Error("failed to send ICE message: channel is full", mlog.String("sessionID", cfg.SessionID)) } }) peerConn.OnICEGatheringStateChange(func(state webrtc.ICEGathererState) { if state == webrtc.ICEGathererStateComplete { - s.log.Debug("ice gathering complete") + s.log.Debug("ice gathering complete", mlog.String("sessionID", cfg.SessionID)) } }) @@ -213,27 +213,34 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { }) peerConn.OnTrack(func(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - s.log.Debug("Got remote track!!!") - s.log.Debug(fmt.Sprintf("%+v", remoteTrack.Codec().RTPCodecCapability)) - s.log.Debug(fmt.Sprintf("Track has started, of type %d: %s", remoteTrack.PayloadType(), remoteTrack.Codec().MimeType)) - streamID := remoteTrack.StreamID() + trackType := remoteTrack.Codec().MimeType + + s.log.Debug("new track received", + mlog.Any("codec", remoteTrack.Codec().RTPCodecCapability), + mlog.Int("payload", int(remoteTrack.PayloadType())), + mlog.String("type", trackType), + mlog.String("streamID", streamID), + mlog.String("remoteTrackID", remoteTrack.ID()), + mlog.Int("SSRC", int(remoteTrack.SSRC())), + mlog.String("sessionID", us.cfg.SessionID), + ) var screenStreamID string if screenSession := call.getScreenSession(); screenSession != nil { screenStreamID = screenSession.getScreenStreamID() } - if remoteTrack.Codec().MimeType == rtpAudioCodec.MimeType { + if trackType == rtpAudioCodec.MimeType { trackType := "voice" if streamID == screenStreamID { - s.log.Debug("received screen sharing audio track") + s.log.Debug("received screen sharing audio track", mlog.String("sessionID", us.cfg.SessionID)) trackType = "screen-audio" } - outAudioTrack, err := webrtc.NewTrackLocalStaticRTP(rtpAudioCodec, trackType, random.NewID()) + outAudioTrack, err := webrtc.NewTrackLocalStaticRTP(rtpAudioCodec, genTrackID(trackType, us.cfg.SessionID), random.NewID()) if err != nil { - s.log.Error("failed to create local track", mlog.Err(err)) + s.log.Error("failed to create local track", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) return } @@ -262,14 +269,16 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { buf := s.bufPool.Get().([]byte) i, _, err := remoteTrack.Read(buf) if err != nil { - s.log.Error("failed to read RTP packet", mlog.Err(err)) + s.log.Error("failed to read RTP packet", + mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) s.metrics.IncRTCErrors(us.cfg.GroupID, "rtp") return } rtp := &rtp.Packet{} if err := rtp.Unmarshal(buf[:i]); err != nil { - s.log.Error("failed to unmarshal RTP packet", mlog.Err(err)) + s.log.Error("failed to unmarshal RTP packet", + mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) s.metrics.IncRTCErrors(us.cfg.GroupID, "rtp") return } @@ -282,12 +291,14 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { isEnabled := us.outVoiceTrackEnabled us.mut.RUnlock() if !isEnabled { + s.bufPool.Put(buf) continue } } if err := outAudioTrack.WriteRTP(rtp); err != nil && !errors.Is(err, io.ErrClosedPipe) { - s.log.Error("failed to write RTP packet", mlog.Err(err)) + s.log.Error("failed to write RTP packet", + mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) s.metrics.IncRTCErrors(us.cfg.GroupID, "rtp") return } @@ -302,17 +313,19 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { s.metrics.AddRTPPacketBytes("out", trackType, pLen) }) } - } else if remoteTrack.Codec().MimeType == rtpVideoCodecVP8.MimeType { + } else if trackType == rtpVideoCodecVP8.MimeType { if screenStreamID != "" && screenStreamID != streamID { - s.log.Error("received unexpected video track", mlog.String("streamID", streamID)) + s.log.Error("received unexpected video track", + mlog.String("streamID", streamID), mlog.String("sessionID", us.cfg.SessionID)) return } - s.log.Debug("received screen sharing stream", mlog.String("streamID", streamID)) + s.log.Debug("received screen sharing stream", mlog.String("streamID", streamID), mlog.String("sessionID", us.cfg.SessionID)) - outScreenTrack, err := webrtc.NewTrackLocalStaticRTP(rtpVideoCodecVP8, "screen", random.NewID()) + outScreenTrack, err := webrtc.NewTrackLocalStaticRTP(rtpVideoCodecVP8, genTrackID("screen", us.cfg.SessionID), random.NewID()) if err != nil { - s.log.Error("failed to create local track", mlog.Err(err)) + s.log.Error("failed to create local track", + mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) return } us.mut.Lock() @@ -328,14 +341,19 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { case ss.tracksCh <- outScreenTrack: default: s.log.Error("failed to send screen track: channel is full", - mlog.String("UserID", us.cfg.UserID), mlog.String("trackUserID", ss.cfg.UserID)) + mlog.String("UserID", us.cfg.UserID), + mlog.String("sessionID", us.cfg.SessionID), + mlog.String("trackUserID", ss.cfg.UserID), + mlog.String("trackSessionID", ss.cfg.SessionID), + ) } }) for { rtp, _, readErr := remoteTrack.ReadRTP() if readErr != nil { - s.log.Error("failed to read RTP packet", mlog.Err(readErr)) + s.log.Error("failed to read RTP packet", + mlog.Err(readErr), mlog.String("sessionID", us.cfg.SessionID)) s.metrics.IncRTCErrors(us.cfg.GroupID, "rtp") return } @@ -344,7 +362,8 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { s.metrics.AddRTPPacketBytes("in", "screen", len(rtp.Payload)) if err := outScreenTrack.WriteRTP(rtp); err != nil && !errors.Is(err, io.ErrClosedPipe) { - s.log.Error("failed to write RTP packet", mlog.Err(err)) + s.log.Error("failed to write RTP packet", + mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) s.metrics.IncRTCErrors(us.cfg.GroupID, "rtp") return } @@ -392,7 +411,7 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { go func() { if err := s.handleTracks(call, us); err != nil { - s.log.Error("handleTracks failed", mlog.Err(err)) + s.log.Error("handleTracks failed", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) } }() }() @@ -469,25 +488,24 @@ func (s *Server) handleTracks(call *call, us *session) error { ss.mut.RLock() outVoiceTrack := ss.outVoiceTrack - isEnabled := ss.outVoiceTrackEnabled outScreenTrack := ss.outScreenTrack outScreenAudioTrack := ss.outScreenAudioTrack ss.mut.RUnlock() if outVoiceTrack != nil { - if err := us.addTrack(s.log, call, s.receiveCh, outVoiceTrack, isEnabled); err != nil { + if err := us.addTrack(s.log, call, s.receiveCh, outVoiceTrack); err != nil { s.metrics.IncRTCErrors(us.cfg.GroupID, "track") s.log.Error("failed to add voice track", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) } } if outScreenTrack != nil { - if err := us.addTrack(s.log, call, s.receiveCh, outScreenTrack, true); err != nil { + if err := us.addTrack(s.log, call, s.receiveCh, outScreenTrack); err != nil { s.metrics.IncRTCErrors(us.cfg.GroupID, "track") s.log.Error("failed to add screen track", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) } } if outScreenAudioTrack != nil { - if err := us.addTrack(s.log, call, s.receiveCh, outScreenAudioTrack, true); err != nil { + if err := us.addTrack(s.log, call, s.receiveCh, outScreenAudioTrack); err != nil { s.metrics.IncRTCErrors(us.cfg.GroupID, "track") s.log.Error("failed to add screen audio track", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) } @@ -501,7 +519,7 @@ func (s *Server) handleTracks(call *call, us *session) error { if !ok { return nil } - if err := us.addTrack(s.log, call, s.receiveCh, track, true); err != nil { + if err := us.addTrack(s.log, call, s.receiveCh, track); err != nil { s.metrics.IncRTCErrors(us.cfg.GroupID, "track") s.log.Error("failed to add track", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) continue @@ -514,61 +532,15 @@ func (s *Server) handleTracks(call *call, us *session) error { sdp, err := us.signaling(msg) if err != nil { s.metrics.IncRTCErrors(us.cfg.GroupID, "signaling") - s.log.Error("failed to signal", mlog.Err(err)) + s.log.Error("failed to signal", mlog.Err(err), mlog.String("sessionID", us.cfg.SessionID)) continue } select { case s.receiveCh <- newMessage(us, SDPMessage, sdp): default: - s.log.Error("failed to send SDP message: channel is full") - } - case muted, ok := <-us.trackEnableCh: - if !ok { - return nil - } - - us.mut.RLock() - track := us.outVoiceTrack - us.mut.RUnlock() - - if track == nil { - continue - } - - us.mut.Lock() - us.outVoiceTrackEnabled = !muted - us.mut.Unlock() - - dummyTrack, err := webrtc.NewTrackLocalStaticRTP(rtpAudioCodec, "voice", random.NewID()) - if err != nil { - s.log.Error("failed to create local track", mlog.Err(err)) - continue + s.log.Error("failed to send SDP message: channel is full", mlog.String("sessionID", us.cfg.SessionID)) } - - call.iterSessions(func(ss *session) { - if ss.cfg.UserID == us.cfg.UserID { - return - } - - ss.mut.RLock() - sender := ss.rtpSendersMap[track] - ss.mut.RUnlock() - - var replacingTrack *webrtc.TrackLocalStaticRTP - if muted { - replacingTrack = dummyTrack - } else { - replacingTrack = track - } - - if sender != nil { - s.log.Debug("replacing track on sender") - if err := sender.ReplaceTrack(replacingTrack); err != nil { - s.log.Error("failed to replace track", mlog.Err(err), mlog.String("sessionID", ss.cfg.SessionID)) - } - } - }) case <-us.closeCh: return nil } diff --git a/service/rtc/utils.go b/service/rtc/utils.go index c4e41f3..b379a39 100644 --- a/service/rtc/utils.go +++ b/service/rtc/utils.go @@ -8,6 +8,8 @@ import ( "fmt" "net" "time" + + "github.com/mattermost/rtcd/service/random" ) func resolveHost(host string, timeout time.Duration) (string, error) { @@ -24,3 +26,7 @@ func resolveHost(host string, timeout time.Duration) (string, error) { } return ip, err } + +func genTrackID(trackType, baseID string) string { + return trackType + "_" + baseID + "_" + random.NewID()[0:8] +}