From c12addb2a8e7b481eee13967105c077af2585640 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 7 Nov 2024 14:33:02 +0100 Subject: [PATCH 1/5] Notify remote to stop publishing when last local subscriber is closed. --- mcu_common.go | 3 +- mcu_janus.go | 2 + mcu_janus_publisher.go | 10 ++-- mcu_janus_remote_publisher.go | 6 ++ mcu_proxy.go | 2 +- mcu_proxy_test.go | 104 ++++++++++++++++++++++++++++++++++ mcu_test.go | 2 +- proxy/proxy_server.go | 54 +++++++++++++++++- proxy/proxy_server_test.go | 2 +- 9 files changed, 175 insertions(+), 10 deletions(-) diff --git a/mcu_common.go b/mcu_common.go index c5ce443c..af22d4af 100644 --- a/mcu_common.go +++ b/mcu_common.go @@ -166,6 +166,7 @@ type RemotePublisherController interface { PublisherId() string StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error + StopPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error GetStreams(ctx context.Context) ([]PublisherStream, error) } @@ -214,7 +215,7 @@ type McuPublisher interface { GetStreams(ctx context.Context) ([]PublisherStream, error) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error - UnpublishRemote(ctx context.Context, remoteId string) error + UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error } type McuSubscriber interface { diff --git a/mcu_janus.go b/mcu_janus.go index fbf53c72..e25adca4 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -785,6 +785,8 @@ func (m *mcuJanus) getOrCreateRemotePublisher(ctx context.Context, controller Re settings: settings, }, + controller: controller, + port: int(port), rtcpPort: int(rtcp_port), } diff --git a/mcu_janus_publisher.go b/mcu_janus_publisher.go index a30b9a6a..9e82d80c 100644 --- a/mcu_janus_publisher.go +++ b/mcu_janus_publisher.go @@ -380,8 +380,8 @@ func (p *mcuJanusPublisher) GetStreams(ctx context.Context) ([]PublisherStream, return streams, nil } -func getPublisherRemoteId(id string, remoteId string) string { - return fmt.Sprintf("%s@%s", id, remoteId) +func getPublisherRemoteId(id string, remoteId string, hostname string, port int, rtcpPort int) string { + return fmt.Sprintf("%s-%s@%s:%d:%d", id, remoteId, hostname, port, rtcpPort) } func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { @@ -389,7 +389,7 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string, "request": "publish_remotely", "room": p.roomId, "publisher_id": streamTypeUserIds[p.streamType], - "remote_id": getPublisherRemoteId(p.id, remoteId), + "remote_id": getPublisherRemoteId(p.id, remoteId, hostname, port, rtcpPort), "host": hostname, "port": port, "rtcp_port": rtcpPort, @@ -421,12 +421,12 @@ func (p *mcuJanusPublisher) PublishRemote(ctx context.Context, remoteId string, return nil } -func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { +func (p *mcuJanusPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { msg := map[string]interface{}{ "request": "unpublish_remotely", "room": p.roomId, "publisher_id": streamTypeUserIds[p.streamType], - "remote_id": getPublisherRemoteId(p.id, remoteId), + "remote_id": getPublisherRemoteId(p.id, remoteId, hostname, port, rtcpPort), } response, err := p.handle.Request(ctx, msg) if err != nil { diff --git a/mcu_janus_remote_publisher.go b/mcu_janus_remote_publisher.go index f834d2ba..9a3575bb 100644 --- a/mcu_janus_remote_publisher.go +++ b/mcu_janus_remote_publisher.go @@ -34,6 +34,8 @@ type mcuJanusRemotePublisher struct { ref atomic.Int64 + controller RemotePublisherController + port int rtcpPort int } @@ -116,6 +118,10 @@ func (p *mcuJanusRemotePublisher) Close(ctx context.Context) { return } + if err := p.controller.StopPublishing(ctx, p); err != nil { + log.Printf("Error stopping remote publisher %s in room %d: %s", p.id, p.roomId, err) + } + p.mu.Lock() if handle := p.handle; handle != nil { response, err := p.handle.Request(ctx, map[string]interface{}{ diff --git a/mcu_proxy.go b/mcu_proxy.go index 0abbaa0c..0d8c5375 100644 --- a/mcu_proxy.go +++ b/mcu_proxy.go @@ -227,7 +227,7 @@ func (p *mcuProxyPublisher) PublishRemote(ctx context.Context, remoteId string, return errors.New("remote publishing not supported for proxy publishers") } -func (p *mcuProxyPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { +func (p *mcuProxyPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { return errors.New("remote publishing not supported for proxy publishers") } diff --git a/mcu_proxy_test.go b/mcu_proxy_test.go index c0f82be1..73c3260a 100644 --- a/mcu_proxy_test.go +++ b/mcu_proxy_test.go @@ -1502,6 +1502,110 @@ func Test_ProxyRemotePublisher(t *testing.T) { defer sub.Close(context.Background()) } +func Test_ProxyMultipleRemotePublisher(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + + etcd := NewEtcdForTest(t) + + grpcServer1, addr1 := NewGrpcServerForTest(t) + grpcServer2, addr2 := NewGrpcServerForTest(t) + grpcServer3, addr3 := NewGrpcServerForTest(t) + + hub1 := &mockGrpcServerHub{} + hub2 := &mockGrpcServerHub{} + hub3 := &mockGrpcServerHub{} + grpcServer1.hub = hub1 + grpcServer2.hub = hub2 + grpcServer3.hub = hub3 + + SetEtcdValue(etcd, "/grpctargets/one", []byte("{\"address\":\""+addr1+"\"}")) + SetEtcdValue(etcd, "/grpctargets/two", []byte("{\"address\":\""+addr2+"\"}")) + SetEtcdValue(etcd, "/grpctargets/three", []byte("{\"address\":\""+addr3+"\"}")) + + server1 := NewProxyServerForTest(t, "DE") + server2 := NewProxyServerForTest(t, "US") + server3 := NewProxyServerForTest(t, "US") + + mcu1 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + server3, + }, + }) + mcu2 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + server3, + }, + }) + mcu3 := newMcuProxyForTestWithOptions(t, proxyTestOptions{ + etcd: etcd, + servers: []*TestProxyServerHandler{ + server1, + server2, + server3, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "the-publisher" + pubSid := "1234567890" + pubListener := &MockMcuListener{ + publicId: pubId + "-public", + } + pubInitiator := &MockMcuInitiator{ + country: "DE", + } + + session1 := &ClientSession{ + publicId: pubId, + publishers: make(map[StreamType]McuPublisher), + } + hub1.addSession(session1) + defer hub1.removeSession(session1) + + pub, err := mcu1.NewPublisher(ctx, pubListener, pubId, pubSid, StreamTypeVideo, NewPublisherSettings{ + MediaTypes: MediaTypeVideo | MediaTypeAudio, + }, pubInitiator) + require.NoError(t, err) + + defer pub.Close(context.Background()) + + session1.mu.Lock() + session1.publishers[StreamTypeVideo] = pub + session1.publisherWaiters.Wakeup() + session1.mu.Unlock() + + sub1Listener := &MockMcuListener{ + publicId: "subscriber-public-1", + } + sub1Initiator := &MockMcuInitiator{ + country: "US", + } + sub1, err := mcu2.NewSubscriber(ctx, sub1Listener, pubId, StreamTypeVideo, sub1Initiator) + require.NoError(t, err) + + defer sub1.Close(context.Background()) + + sub2Listener := &MockMcuListener{ + publicId: "subscriber-public-2", + } + sub2Initiator := &MockMcuInitiator{ + country: "US", + } + sub2, err := mcu3.NewSubscriber(ctx, sub2Listener, pubId, StreamTypeVideo, sub2Initiator) + require.NoError(t, err) + + defer sub2.Close(context.Background()) +} + func Test_ProxyRemotePublisherWait(t *testing.T) { CatchLogForTest(t) t.Parallel() diff --git a/mcu_test.go b/mcu_test.go index ee84bdef..1fb6841f 100644 --- a/mcu_test.go +++ b/mcu_test.go @@ -229,7 +229,7 @@ func (p *TestMCUPublisher) PublishRemote(ctx context.Context, remoteId string, h return errors.New("remote publishing not supported") } -func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { +func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { return errors.New("remote publishing not supported") } diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index d3655e8f..0e6634dc 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -856,6 +856,28 @@ func (p *proxyRemotePublisher) StartPublishing(ctx context.Context, publisher si return nil } +func (p *proxyRemotePublisher) StopPublishing(ctx context.Context, publisher signaling.McuRemotePublisherProperties) error { + conn, err := p.proxy.getRemoteConnection(p.remoteUrl) + if err != nil { + return err + } + + if _, err := conn.RequestMessage(ctx, &signaling.ProxyClientMessage{ + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "unpublish-remote", + ClientId: p.publisherId, + Hostname: p.proxy.remoteHostname, + Port: publisher.Port(), + RtcpPort: publisher.RtcpPort(), + }, + }); err != nil { + return err + } + + return nil +} + func (p *proxyRemotePublisher) GetStreams(ctx context.Context) ([]signaling.PublisherStream, error) { conn, err := p.proxy.getRemoteConnection(p.remoteUrl) if err != nil { @@ -1125,7 +1147,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s ctx2, cancel = context.WithTimeout(ctx, s.mcuTimeout) defer cancel() - if err := publisher.UnpublishRemote(ctx2, session.PublicId()); err != nil { + if err := publisher.UnpublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { log.Printf("Error unpublishing old %s %s to remote %s (port=%d, rtcpPort=%d): %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, cmd.Port, cmd.RtcpPort, err) session.sendMessage(message.NewWrappedErrorServerMessage(err)) return @@ -1141,6 +1163,36 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } } + response := &signaling.ProxyServerMessage{ + Id: message.Id, + Type: "command", + Command: &signaling.CommandProxyServerMessage{ + Id: cmd.ClientId, + }, + } + session.sendMessage(response) + case "unpublish-remote": + client := s.GetClient(cmd.ClientId) + if client == nil { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + publisher, ok := client.(signaling.McuPublisher) + if !ok { + session.sendMessage(message.NewErrorServerMessage(UnknownClient)) + return + } + + ctx2, cancel := context.WithTimeout(ctx, s.mcuTimeout) + defer cancel() + + if err := publisher.UnpublishRemote(ctx2, session.PublicId(), cmd.Hostname, cmd.Port, cmd.RtcpPort); err != nil { + log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), cmd.ClientId, cmd.Hostname, err) + session.sendMessage(message.NewWrappedErrorServerMessage(err)) + return + } + response := &signaling.ProxyServerMessage{ Id: message.Id, Type: "command", diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 395a1d77..e12b66b5 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -431,7 +431,7 @@ func (p *TestMCUPublisher) PublishRemote(ctx context.Context, remoteId string, h return errors.New("not implemented") } -func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string) error { +func (p *TestMCUPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { return errors.New("not implemented") } From af7eda29b29f09e4e0ae60fafdd65fac7a2dbdaa Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 7 Nov 2024 14:33:50 +0100 Subject: [PATCH 2/5] Add test for remote subscribing. --- proxy/proxy_server_test.go | 217 +++++++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index e12b66b5..cb6f42f0 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -618,3 +618,220 @@ func TestProxyCodecs(t *testing.T) { } } } + +type RemoteSubscriberTestMCU struct { + TestMCU + + publisher *TestRemotePublisher + subscriber *TestRemoteSubscriber +} + +func NewRemoteSubscriberTestMCU(t *testing.T) *RemoteSubscriberTestMCU { + return &RemoteSubscriberTestMCU{ + TestMCU: TestMCU{ + t: t, + }, + } +} + +type TestRemotePublisher struct { + t *testing.T + + streamType signaling.StreamType + refcnt atomic.Int32 + closed context.Context + closeFunc context.CancelFunc +} + +func (p *TestRemotePublisher) Id() string { + return "id" +} + +func (p *TestRemotePublisher) Sid() string { + return "sid" +} + +func (p *TestRemotePublisher) StreamType() signaling.StreamType { + return p.streamType +} + +func (p *TestRemotePublisher) MaxBitrate() int { + return 0 +} + +func (p *TestRemotePublisher) Close(ctx context.Context) { + if count := p.refcnt.Add(-1); assert.True(p.t, count >= 0) && count == 0 { + p.closeFunc() + } +} + +func (p *TestRemotePublisher) SendMessage(ctx context.Context, message *signaling.MessageClientMessage, data *signaling.MessageClientMessageData, callback func(error, map[string]interface{})) { + callback(errors.New("not implemented"), nil) +} + +func (p *TestRemotePublisher) Port() int { + return 1 +} + +func (p *TestRemotePublisher) RtcpPort() int { + return 2 +} + +func (m *RemoteSubscriberTestMCU) NewRemotePublisher(ctx context.Context, listener signaling.McuListener, controller signaling.RemotePublisherController, streamType signaling.StreamType) (signaling.McuRemotePublisher, error) { + require.Nil(m.t, m.publisher) + assert.EqualValues(m.t, "video", streamType) + closeCtx, closeFunc := context.WithCancel(context.Background()) + m.publisher = &TestRemotePublisher{ + t: m.t, + + streamType: streamType, + closed: closeCtx, + closeFunc: closeFunc, + } + m.publisher.refcnt.Add(1) + return m.publisher, nil +} + +type TestRemoteSubscriber struct { + t *testing.T + + publisher *TestRemotePublisher + closed context.Context + closeFunc context.CancelFunc +} + +func (s *TestRemoteSubscriber) Id() string { + return "id" +} + +func (s *TestRemoteSubscriber) Sid() string { + return "sid" +} + +func (s *TestRemoteSubscriber) StreamType() signaling.StreamType { + return s.publisher.StreamType() +} + +func (s *TestRemoteSubscriber) MaxBitrate() int { + return 0 +} + +func (s *TestRemoteSubscriber) Close(ctx context.Context) { + s.publisher.Close(ctx) + s.closeFunc() +} + +func (s *TestRemoteSubscriber) SendMessage(ctx context.Context, message *signaling.MessageClientMessage, data *signaling.MessageClientMessageData, callback func(error, map[string]interface{})) { + callback(errors.New("not implemented"), nil) +} + +func (s *TestRemoteSubscriber) Publisher() string { + return s.publisher.Id() +} + +func (m *RemoteSubscriberTestMCU) NewRemoteSubscriber(ctx context.Context, listener signaling.McuListener, publisher signaling.McuRemotePublisher) (signaling.McuRemoteSubscriber, error) { + require.Nil(m.t, m.subscriber) + pub, ok := publisher.(*TestRemotePublisher) + require.True(m.t, ok) + closeCtx, closeFunc := context.WithCancel(context.Background()) + m.subscriber = &TestRemoteSubscriber{ + t: m.t, + + publisher: pub, + closed: closeCtx, + closeFunc: closeFunc, + } + pub.refcnt.Add(1) + return m.subscriber, nil +} + +func TestProxyRemoteSubscriber(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewRemoteSubscriberTestMCU(t) + proxy.mcu = mcu + // Unused but must be set so remote subscribing works + proxy.tokenId = "token" + proxy.tokenKey = key + proxy.remoteHostname = "test-hostname" + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + Subject: publisherId, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + require.NoError(err) + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: signaling.StreamTypeVideo, + PublisherId: publisherId, + RemoteUrl: "https://remote-hostname", + RemoteToken: tokenString, + }, + })) + + var clientId string + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "delete-subscriber", + ClientId: clientId, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + assert.Equal(clientId, message.Command.Id) + } + } + + if assert.NotNil(mcu.publisher) && assert.NotNil(mcu.subscriber) { + select { + case <-mcu.subscriber.closed.Done(): + case <-ctx.Done(): + assert.Fail("subscriber was not closed") + } + select { + case <-mcu.publisher.closed.Done(): + case <-ctx.Done(): + assert.Fail("publisher was not closed") + } + } +} From 469e97f4838f388124eda02f7ba6f5c716f9c724 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 7 Nov 2024 16:57:12 +0100 Subject: [PATCH 3/5] Make JanusGateway an interface to help with testing. --- janus_client.go | 25 ++++++++++++++++++++----- mcu_janus.go | 27 ++++++++++++++++++++++----- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/janus_client.go b/janus_client.go index b7b33a5f..5716bef8 100644 --- a/janus_client.go +++ b/janus_client.go @@ -219,6 +219,17 @@ type dummyGatewayListener struct { func (l *dummyGatewayListener) ConnectionInterrupted() { } +type JanusGatewayInterface interface { + Info(context.Context) (*InfoMsg, error) + Create(context.Context) (*JanusSession, error) + Close() error + + send(map[string]interface{}, *transaction) (uint64, error) + removeTransaction(uint64) + + removeSession(*JanusSession) +} + // Gateway represents a connection to an instance of the Janus Gateway. type JanusGateway struct { listener GatewayListener @@ -560,12 +571,18 @@ func (gateway *JanusGateway) Create(ctx context.Context) (*JanusSession, error) // Store this session gateway.Lock() + defer gateway.Unlock() gateway.Sessions[session.Id] = session - gateway.Unlock() return session, nil } +func (gateway *JanusGateway) removeSession(session *JanusSession) { + gateway.Lock() + defer gateway.Unlock() + delete(gateway.Sessions, session.Id) +} + // Session represents a session instance on the Janus Gateway. type JanusSession struct { // Id is the session_id of this session @@ -578,7 +595,7 @@ type JanusSession struct { // and Session.Unlock() methods provided by the embedded sync.Mutex. sync.Mutex - gateway *JanusGateway + gateway JanusGatewayInterface } func (session *JanusSession) send(msg map[string]interface{}, t *transaction) (uint64, error) { @@ -670,9 +687,7 @@ func (session *JanusSession) Destroy(ctx context.Context) (*janus.AckMsg, error) } // Remove this session from the gateway - session.gateway.Lock() - delete(session.gateway.Sessions, session.Id) - session.gateway.Unlock() + session.gateway.removeSession(session) return ack, nil } diff --git a/mcu_janus.go b/mcu_janus.go index e25adca4..1a680036 100644 --- a/mcu_janus.go +++ b/mcu_janus.go @@ -78,6 +78,11 @@ func convertIntValue(value interface{}) (uint64, error) { return uint64(t), nil case uint64: return t, nil + case int: + if t < 0 { + return 0, fmt.Errorf("Unsupported int number: %+v", t) + } + return uint64(t), nil case int64: if t < 0 { return 0, fmt.Errorf("Unsupported int64 number: %+v", t) @@ -92,7 +97,7 @@ func convertIntValue(value interface{}) (uint64, error) { } return uint64(r), nil default: - return 0, fmt.Errorf("Unknown number type: %+v", t) + return 0, fmt.Errorf("Unknown number type: %+v (%T)", t, t) } } @@ -170,7 +175,9 @@ type mcuJanus struct { settings McuSettings - gw *JanusGateway + createJanusGateway func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error) + + gw JanusGatewayInterface session *JanusSession handle *JanusHandle @@ -213,6 +220,9 @@ func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mc publishers: make(map[string]*mcuJanusPublisher), remotePublishers: make(map[string]*mcuJanusRemotePublisher), + createJanusGateway: func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error) { + return NewJanusGateway(ctx, wsURL, listener) + }, reconnectInterval: initialReconnectInterval, } mcu.onConnected.Store(emptyOnConnected) @@ -222,8 +232,10 @@ func NewMcuJanus(ctx context.Context, url string, config *goconf.ConfigFile) (Mc mcu.doReconnect(context.Background()) }) mcu.reconnectTimer.Stop() - if err := mcu.reconnect(ctx); err != nil { - return nil, err + if mcu.url != "" { + if err := mcu.reconnect(ctx); err != nil { + return nil, err + } } return mcu, nil } @@ -252,7 +264,7 @@ func (m *mcuJanus) disconnect() { func (m *mcuJanus) reconnect(ctx context.Context) error { m.disconnect() - gw, err := NewJanusGateway(ctx, m.url, m) + gw, err := m.createJanusGateway(ctx, m.url, m) if err != nil { return err } @@ -317,6 +329,11 @@ func (m *mcuJanus) hasRemotePublisher() bool { } func (m *mcuJanus) Start(ctx context.Context) error { + if m.url == "" { + if err := m.reconnect(ctx); err != nil { + return err + } + } info, err := m.gw.Info(ctx) if err != nil { return err From 6038d667307031e0e415721379d07b525a8a50c3 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Thu, 7 Nov 2024 16:57:31 +0100 Subject: [PATCH 4/5] Add some Janus tests. --- mcu_janus_test.go | 584 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 584 insertions(+) create mode 100644 mcu_janus_test.go diff --git a/mcu_janus_test.go b/mcu_janus_test.go new file mode 100644 index 00000000..f7407d80 --- /dev/null +++ b/mcu_janus_test.go @@ -0,0 +1,584 @@ +/** + * Standalone signaling server for the Nextcloud Spreed app. + * Copyright (C) 2024 struktur AG + * + * @author Joachim Bauch + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package signaling + +import ( + "context" + "encoding/json" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/dlintw/goconf" + "github.com/notedit/janus-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type TestJanusHandle struct { + id uint64 +} + +type TestJanusRoom struct { + id uint64 +} + +type TestJanusHandler func(room *TestJanusRoom, body map[string]interface{}) (interface{}, *janus.ErrorMsg) + +type TestJanusGateway struct { + t *testing.T + + sid atomic.Uint64 + tid atomic.Uint64 + hid atomic.Uint64 + rid atomic.Uint64 + mu sync.Mutex + + sessions map[uint64]*JanusSession + transactions map[uint64]*transaction + handles map[uint64]*TestJanusHandle + rooms map[uint64]*TestJanusRoom + handlers map[string]TestJanusHandler +} + +func NewTestJanusGateway(t *testing.T) *TestJanusGateway { + gateway := &TestJanusGateway{ + t: t, + + sessions: make(map[uint64]*JanusSession), + transactions: make(map[uint64]*transaction), + handles: make(map[uint64]*TestJanusHandle), + rooms: make(map[uint64]*TestJanusRoom), + handlers: make(map[string]TestJanusHandler), + } + + t.Cleanup(func() { + assert := assert.New(t) + gateway.mu.Lock() + defer gateway.mu.Unlock() + assert.Len(gateway.sessions, 0) + assert.Len(gateway.transactions, 0) + assert.Len(gateway.handles, 0) + assert.Len(gateway.rooms, 0) + }) + + return gateway +} + +func (g *TestJanusGateway) registerHandlers(handlers map[string]TestJanusHandler) { + g.mu.Lock() + defer g.mu.Unlock() + for name, handler := range handlers { + g.handlers[name] = handler + } +} + +func (g *TestJanusGateway) Info(ctx context.Context) (*InfoMsg, error) { + return &InfoMsg{ + Name: "TestJanus", + Version: 1400, + VersionString: "1.4.0", + Author: "struktur AG", + DataChannels: true, + FullTrickle: true, + Plugins: map[string]janus.PluginInfo{ + pluginVideoRoom: { + Name: "Test VideoRoom plugin", + VersionString: "0.0.0", + Author: "struktur AG", + }, + }, + }, nil +} + +func (g *TestJanusGateway) Create(ctx context.Context) (*JanusSession, error) { + sid := g.sid.Add(1) + session := &JanusSession{ + Id: sid, + Handles: make(map[uint64]*JanusHandle), + gateway: g, + } + g.mu.Lock() + defer g.mu.Unlock() + g.sessions[sid] = session + return session, nil +} + +func (g *TestJanusGateway) Close() error { + return nil +} + +func (g *TestJanusGateway) processMessage(session *JanusSession, handle *TestJanusHandle, body map[string]interface{}) interface{} { + request := body["request"].(string) + switch request { + case "create": + room := &TestJanusRoom{ + id: g.rid.Add(1), + } + g.rooms[room.id] = room + + return &janus.SuccessMsg{ + PluginData: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: map[string]interface{}{ + "room": room.id, + }, + }, + } + case "join": + rid := body["room"].(float64) + room := g.rooms[uint64(rid)] + if room == nil { + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM, + Reason: "Room not found", + }, + } + } + + assert.Equal(g.t, "publisher", body["ptype"]) + return &janus.EventMsg{ + Session: session.Id, + Handle: handle.id, + Plugindata: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: map[string]interface{}{ + "room": room.id, + }, + }, + } + case "destroy": + rid := body["room"].(float64) + room := g.rooms[uint64(rid)] + if room == nil { + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM, + Reason: "Room not found", + }, + } + } + + delete(g.rooms, uint64(rid)) + + return &janus.SuccessMsg{ + PluginData: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: map[string]interface{}{}, + }, + } + default: + rid := body["room"].(float64) + room := g.rooms[uint64(rid)] + if room == nil { + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_VIDEOROOM_ERROR_NO_SUCH_ROOM, + Reason: "Room not found", + }, + } + } + + handler, found := g.handlers[request] + if found { + var err *janus.ErrorMsg + result, err := handler(room, body) + if err != nil { + result = err + } + return result + } + } + + return nil +} + +func (g *TestJanusGateway) processRequest(msg map[string]interface{}) interface{} { + method, found := msg["janus"] + if !found { + return nil + } + + sid := msg["session_id"].(float64) + g.mu.Lock() + defer g.mu.Unlock() + session := g.sessions[uint64(sid)] + if session == nil { + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_ERROR_SESSION_NOT_FOUND, + Reason: "Session not found", + }, + } + } + + switch method { + case "attach": + handle := &TestJanusHandle{ + id: g.hid.Add(1), + } + + g.handles[handle.id] = handle + + return &janus.SuccessMsg{ + Data: janus.SuccessData{ + ID: handle.id, + }, + } + case "detach": + hid := msg["handle_id"].(float64) + handle, found := g.handles[uint64(hid)] + if found { + delete(g.handles, handle.id) + } + if handle == nil { + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_ERROR_HANDLE_NOT_FOUND, + Reason: "Handle not found", + }, + } + } + + return &janus.AckMsg{} + case "destroy": + delete(g.sessions, session.Id) + return &janus.AckMsg{} + case "message": + hid := msg["handle_id"].(float64) + handle, found := g.handles[uint64(hid)] + if !found { + return &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_ERROR_HANDLE_NOT_FOUND, + Reason: "Handle not found", + }, + } + } + + body := msg["body"].(map[string]interface{}) + return g.processMessage(session, handle, body) + } + + return nil +} + +func (g *TestJanusGateway) send(msg map[string]interface{}, t *transaction) (uint64, error) { + tid := g.tid.Add(1) + + data, err := json.Marshal(msg) + require.NoError(g.t, err) + err = json.Unmarshal(data, &msg) + require.NoError(g.t, err) + + go t.run() + + g.mu.Lock() + defer g.mu.Unlock() + g.transactions[tid] = t + + go func() { + result := g.processRequest(msg) + if !assert.NotNil(g.t, result, "Unsupported request %+v", msg) { + result = &janus.ErrorMsg{ + Err: janus.ErrorData{ + Code: JANUS_ERROR_UNKNOWN, + Reason: "Not implemented", + }, + } + } + + t.add(result) + }() + + return tid, nil +} + +func (g *TestJanusGateway) removeTransaction(id uint64) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.transactions, id) +} + +func (g *TestJanusGateway) removeSession(session *JanusSession) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.sessions, session.Id) +} + +func newMcuJanusForTesting(t *testing.T) (*mcuJanus, *TestJanusGateway) { + gateway := NewTestJanusGateway(t) + + config := goconf.NewConfigFile() + mcu, err := NewMcuJanus(context.Background(), "", config) + require.NoError(t, err) + t.Cleanup(func() { + mcu.Stop() + }) + + mcuJanus := mcu.(*mcuJanus) + mcuJanus.createJanusGateway = func(ctx context.Context, wsURL string, listener GatewayListener) (JanusGatewayInterface, error) { + return gateway, nil + } + require.NoError(t, mcu.Start(context.Background())) + return mcuJanus, gateway +} + +type TestMcuListener struct { + id string +} + +func (t *TestMcuListener) PublicId() string { + return t.id +} + +func (t *TestMcuListener) OnUpdateOffer(client McuClient, offer map[string]interface{}) { + +} + +func (t *TestMcuListener) OnIceCandidate(client McuClient, candidate interface{}) { + +} + +func (t *TestMcuListener) OnIceCompleted(client McuClient) { + +} + +func (t *TestMcuListener) SubscriberSidUpdated(subscriber McuSubscriber) { + +} + +func (t *TestMcuListener) PublisherClosed(publisher McuPublisher) { + +} + +func (t *TestMcuListener) SubscriberClosed(subscriber McuSubscriber) { + +} + +type TestMcuController struct { + id string +} + +func (c *TestMcuController) PublisherId() string { + return c.id +} + +func (c *TestMcuController) StartPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error { + // TODO: Check parameters? + return nil +} + +func (c *TestMcuController) StopPublishing(ctx context.Context, publisher McuRemotePublisherProperties) error { + // TODO: Check parameters? + return nil +} + +func (c *TestMcuController) GetStreams(ctx context.Context) ([]PublisherStream, error) { + streams := []PublisherStream{ + { + Mid: "0", + Mindex: 0, + Type: "audio", + Codec: "opus", + }, + } + return streams, nil +} + +type TestMcuInitiator struct { + country string +} + +func (i *TestMcuInitiator) Country() string { + return i.country +} + +func Test_JanusPublisherSubscriber(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + require := require.New(t) + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{}) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "publisher-id" + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", StreamTypeVideo, settings1, initiator1) + require.NoError(err) + defer pub.Close(context.Background()) + + listener2 := &TestMcuListener{ + id: pubId, + } + + initiator2 := &TestMcuInitiator{ + country: "DE", + } + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, StreamTypeVideo, initiator2) + require.NoError(err) + defer sub.Close(context.Background()) +} + +func Test_JanusSubscriberPublisher(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + require := require.New(t) + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{}) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + pubId := "publisher-id" + listener1 := &TestMcuListener{ + id: pubId, + } + + settings1 := NewPublisherSettings{} + initiator1 := &TestMcuInitiator{ + country: "DE", + } + + ready := make(chan struct{}) + done := make(chan struct{}) + + go func() { + defer close(done) + time.Sleep(100 * time.Millisecond) + pub, err := mcu.NewPublisher(ctx, listener1, pubId, "sid", StreamTypeVideo, settings1, initiator1) + require.NoError(err) + defer func() { + <-ready + pub.Close(context.Background()) + }() + }() + + listener2 := &TestMcuListener{ + id: pubId, + } + + initiator2 := &TestMcuInitiator{ + country: "DE", + } + sub, err := mcu.NewSubscriber(ctx, listener2, pubId, StreamTypeVideo, initiator2) + require.NoError(err) + defer sub.Close(context.Background()) + close(ready) + <-done +} + +func Test_JanusRemotePublisher(t *testing.T) { + CatchLogForTest(t) + t.Parallel() + assert := assert.New(t) + require := require.New(t) + + var added atomic.Int32 + var removed atomic.Int32 + + mcu, gateway := newMcuJanusForTesting(t) + gateway.registerHandlers(map[string]TestJanusHandler{ + "add_remote_publisher": func(room *TestJanusRoom, body map[string]interface{}) (interface{}, *janus.ErrorMsg) { + assert.EqualValues(1, room.id) + if streams := body["streams"].([]interface{}); assert.Len(streams, 1) { + stream := streams[0].(map[string]interface{}) + assert.Equal("0", stream["mid"]) + assert.EqualValues(0, stream["mindex"]) + assert.Equal("audio", stream["type"]) + assert.Equal("opus", stream["codec"]) + } + added.Add(1) + return &janus.SuccessMsg{ + PluginData: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: map[string]interface{}{ + "id": 12345, + "port": 10000, + "rtcp_port": 10001, + }, + }, + }, nil + }, + "remove_remote_publisher": func(room *TestJanusRoom, body map[string]interface{}) (interface{}, *janus.ErrorMsg) { + assert.EqualValues(1, room.id) + removed.Add(1) + return &janus.SuccessMsg{ + PluginData: janus.PluginData{ + Plugin: pluginVideoRoom, + Data: map[string]interface{}{}, + }, + }, nil + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + listener1 := &TestMcuListener{ + id: "publisher-id", + } + + controller := &TestMcuController{ + id: listener1.id, + } + + pub, err := mcu.NewRemotePublisher(ctx, listener1, controller, StreamTypeVideo) + require.NoError(err) + defer pub.Close(context.Background()) + + assert.EqualValues(1, added.Load()) + assert.EqualValues(0, removed.Load()) + + listener2 := &TestMcuListener{ + id: "subscriber-id", + } + + sub, err := mcu.NewRemoteSubscriber(ctx, listener2, pub) + require.NoError(err) + defer sub.Close(context.Background()) + + pub.Close(context.Background()) + + assert.EqualValues(1, added.Load()) + // The publisher is ref-counted, and still referenced by the subscriber. + assert.EqualValues(0, removed.Load()) + + sub.Close(context.Background()) + + assert.EqualValues(1, added.Load()) + assert.EqualValues(1, removed.Load()) +} From 71ceadbf4ce8fb663e9da60dcae26063412fc6d1 Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Mon, 11 Nov 2024 10:17:25 +0100 Subject: [PATCH 5/5] Add more cases when to stop remote publishing. --- proxy/proxy_server.go | 12 + proxy/proxy_server_test.go | 511 +++++++++++++++++++++++++++++++++++++ proxy/proxy_session.go | 83 ++++++ 3 files changed, 606 insertions(+) diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 0e6634dc..b256d099 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -1163,6 +1163,7 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s } } + session.AddRemotePublisher(publisher, cmd.Hostname, cmd.Port, cmd.RtcpPort) response := &signaling.ProxyServerMessage{ Id: message.Id, Type: "command", @@ -1193,6 +1194,8 @@ func (s *ProxyServer) processCommand(ctx context.Context, client *ProxyClient, s return } + session.RemoveRemotePublisher(publisher, cmd.Hostname, cmd.Port, cmd.RtcpPort) + response := &signaling.ProxyServerMessage{ Id: message.Id, Type: "command", @@ -1599,3 +1602,12 @@ func (s *ProxyServer) getRemoteConnection(url string) (*RemoteConnection, error) s.remoteConnections[url] = conn return conn, nil } + +func (s *ProxyServer) PublisherDeleted(publisher signaling.McuPublisher) { + s.sessionsLock.RLock() + defer s.sessionsLock.RUnlock() + + for _, session := range s.sessions { + session.OnPublisherDeleted(publisher) + } +} diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index cb6f42f0..973b6dc3 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -33,6 +33,7 @@ import ( "net/http/httptest" "os" "strings" + "sync" "sync/atomic" "testing" "time" @@ -835,3 +836,513 @@ func TestProxyRemoteSubscriber(t *testing.T) { } } } + +func TestProxyCloseRemoteOnSessionClose(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewRemoteSubscriberTestMCU(t) + proxy.mcu = mcu + // Unused but must be set so remote subscribing works + proxy.tokenId = "token" + proxy.tokenKey = key + proxy.remoteHostname = "test-hostname" + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client := NewProxyTestClient(ctx, t, server.URL) + defer client.CloseWithBye() + + require.NoError(client.SendHello(key)) + + if hello, err := client.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + claims := &signaling.TokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now().Add(-maxTokenAge / 2)), + Issuer: TokenIdForTest, + Subject: publisherId, + }, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + require.NoError(err) + + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-subscriber", + StreamType: signaling.StreamTypeVideo, + PublisherId: publisherId, + RemoteUrl: "https://remote-hostname", + RemoteToken: tokenString, + }, + })) + + if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + // Closing the session will cause any active remote publishers stop be stopped. + client.CloseWithBye() + + if assert.NotNil(mcu.publisher) && assert.NotNil(mcu.subscriber) { + select { + case <-mcu.subscriber.closed.Done(): + case <-ctx.Done(): + assert.Fail("subscriber was not closed") + } + select { + case <-mcu.publisher.closed.Done(): + case <-ctx.Done(): + assert.Fail("publisher was not closed") + } + } +} + +type UnpublishRemoteTestMCU struct { + TestMCU + + publisher atomic.Pointer[UnpublishRemoteTestPublisher] +} + +func NewUnpublishRemoteTestMCU(t *testing.T) *UnpublishRemoteTestMCU { + return &UnpublishRemoteTestMCU{ + TestMCU: TestMCU{ + t: t, + }, + } +} + +type UnpublishRemoteTestPublisher struct { + TestMCUPublisher + + t *testing.T + + mu sync.RWMutex + remoteId string + remoteData *remotePublisherData +} + +func (m *UnpublishRemoteTestMCU) NewPublisher(ctx context.Context, listener signaling.McuListener, id string, sid string, streamType signaling.StreamType, settings signaling.NewPublisherSettings, initiator signaling.McuInitiator) (signaling.McuPublisher, error) { + publisher := &UnpublishRemoteTestPublisher{ + TestMCUPublisher: TestMCUPublisher{ + id: id, + sid: sid, + streamType: streamType, + }, + + t: m.t, + } + m.publisher.Store(publisher) + return publisher, nil +} + +func (p *UnpublishRemoteTestPublisher) getRemoteId() string { + p.mu.RLock() + defer p.mu.RUnlock() + return p.remoteId +} + +func (p *UnpublishRemoteTestPublisher) getRemoteData() *remotePublisherData { + p.mu.RLock() + defer p.mu.RUnlock() + return p.remoteData +} + +func (p *UnpublishRemoteTestPublisher) clearRemote() { + p.mu.Lock() + defer p.mu.Unlock() + p.remoteId = "" + p.remoteData = nil +} + +func (p *UnpublishRemoteTestPublisher) PublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + p.mu.Lock() + defer p.mu.Unlock() + if assert.Empty(p.t, p.remoteId) { + p.remoteId = remoteId + p.remoteData = &remotePublisherData{ + hostname: hostname, + port: port, + rtcpPort: rtcpPort, + } + } + return nil +} + +func (p *UnpublishRemoteTestPublisher) UnpublishRemote(ctx context.Context, remoteId string, hostname string, port int, rtcpPort int) error { + p.mu.Lock() + defer p.mu.Unlock() + assert.Equal(p.t, remoteId, p.remoteId) + if remoteData := p.remoteData; assert.NotNil(p.t, remoteData) && + assert.Equal(p.t, remoteData.hostname, hostname) && + assert.EqualValues(p.t, remoteData.port, port) && + assert.EqualValues(p.t, remoteData.rtcpPort, rtcpPort) { + p.remoteId = "" + p.remoteData = nil + } + return nil +} + +func TestProxyUnpublishRemote(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewUnpublishRemoteTestMCU(t) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewProxyTestClient(ctx, t, server.URL) + defer client1.CloseWithBye() + + require.NoError(client1.SendHello(key)) + + if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client1.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: signaling.StreamTypeVideo, + PublisherSettings: &signaling.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + hello2, err := client2.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2) + } + + _, err = client2.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "4567", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "unpublish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Empty(publisher.getRemoteId()) + assert.Nil(publisher.getRemoteData()) + } +} + +func TestProxyUnpublishRemotePublisherClosed(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewUnpublishRemoteTestMCU(t) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewProxyTestClient(ctx, t, server.URL) + defer client1.CloseWithBye() + + require.NoError(client1.SendHello(key)) + + if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client1.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: signaling.StreamTypeVideo, + PublisherSettings: &signaling.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + hello2, err := client2.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2) + } + + _, err = client2.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "4567", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "delete-publisher", + ClientId: clientId, + }, + })) + + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("4567", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + // Remote publishing was not stopped explicitly... + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + // ...but the session no longer contains information on the remote publisher. + if data, err := proxy.cookie.DecodePublic(hello2.Hello.SessionId); assert.NoError(err) { + session := proxy.GetSession(data.Sid) + if assert.NotNil(session) { + session.remotePublishersLock.Lock() + defer session.remotePublishersLock.Unlock() + assert.Empty(session.remotePublishers) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + publisher.clearRemote() + } +} + +func TestProxyUnpublishRemoteOnSessionClose(t *testing.T) { + signaling.CatchLogForTest(t) + assert := assert.New(t) + require := require.New(t) + proxy, key, server := newProxyServerForTest(t) + + mcu := NewUnpublishRemoteTestMCU(t) + proxy.mcu = mcu + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + client1 := NewProxyTestClient(ctx, t, server.URL) + defer client1.CloseWithBye() + + require.NoError(client1.SendHello(key)) + + if hello, err := client1.RunUntilHello(ctx); assert.NoError(err) { + assert.NotEmpty(hello.Hello.SessionId, "%+v", hello) + } + + _, err := client1.RunUntilLoad(ctx, 0) + assert.NoError(err) + + publisherId := "the-publisher-id" + require.NoError(client1.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + PublisherId: publisherId, + Sid: "1234-abcd", + StreamType: signaling.StreamTypeVideo, + PublisherSettings: &signaling.NewPublisherSettings{ + Bitrate: 1234567, + MediaTypes: signaling.MediaTypeAudio | signaling.MediaTypeVideo, + }, + }, + })) + + var clientId string + if message, err := client1.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("2345", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + clientId = message.Command.Id + } + } + + client2 := NewProxyTestClient(ctx, t, server.URL) + defer client2.CloseWithBye() + + require.NoError(client2.SendHello(key)) + + hello2, err := client2.RunUntilHello(ctx) + if assert.NoError(err) { + assert.NotEmpty(hello2.Hello.SessionId, "%+v", hello2) + } + + _, err = client2.RunUntilLoad(ctx, 0) + assert.NoError(err) + + require.NoError(client2.WriteJSON(&signaling.ProxyClientMessage{ + Id: "3456", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "publish-remote", + StreamType: signaling.StreamTypeVideo, + ClientId: clientId, + Hostname: "remote-host", + Port: 10001, + RtcpPort: 10002, + }, + })) + + if message, err := client2.RunUntilMessage(ctx); assert.NoError(err) { + assert.Equal("3456", message.Id) + if err := checkMessageType(message, "command"); assert.NoError(err) { + require.NotEmpty(message.Command.Id) + } + } + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Equal(hello2.Hello.SessionId, publisher.getRemoteId()) + if remoteData := publisher.getRemoteData(); assert.NotNil(remoteData) { + assert.Equal("remote-host", remoteData.hostname) + assert.EqualValues(10001, remoteData.port) + assert.EqualValues(10002, remoteData.rtcpPort) + } + } + + // Closing the session will cause any active remote publishers stop be stopped. + client2.CloseWithBye() + + if publisher := mcu.publisher.Load(); assert.NotNil(publisher) { + assert.Empty(publisher.getRemoteId()) + assert.Nil(publisher.getRemoteData()) + } +} diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index ed9ac260..de6645be 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -23,6 +23,7 @@ package main import ( "context" + "fmt" "log" "sync" "sync/atomic" @@ -36,6 +37,12 @@ const ( sessionExpirationTime = time.Minute ) +type remotePublisherData struct { + hostname string + port int + rtcpPort int +} + type ProxySession struct { proxy *ProxyServer id string @@ -55,6 +62,9 @@ type ProxySession struct { subscribersLock sync.Mutex subscribers map[string]signaling.McuSubscriber subscriberIds map[signaling.McuSubscriber]string + + remotePublishersLock sync.Mutex + remotePublishers map[signaling.McuPublisher]map[string]*remotePublisherData } func NewProxySession(proxy *ProxyServer, sid uint64, id string) *ProxySession { @@ -121,6 +131,7 @@ func (s *ProxySession) Close() { s.closeFunc() s.clearPublishers() s.clearSubscribers() + s.clearRemotePublishers() s.proxy.DeleteSession(s.Sid()) } @@ -287,6 +298,8 @@ func (s *ProxySession) DeletePublisher(publisher signaling.McuPublisher) string delete(s.publishers, id) delete(s.publisherIds, publisher) + delete(s.remotePublishers, publisher) + go s.proxy.PublisherDeleted(publisher) return id } @@ -329,6 +342,22 @@ func (s *ProxySession) clearPublishers() { clear(s.publisherIds) } +func (s *ProxySession) clearRemotePublishers() { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + go func(remotePublishers map[signaling.McuPublisher]map[string]*remotePublisherData) { + for publisher, entries := range remotePublishers { + for _, data := range entries { + if err := publisher.UnpublishRemote(context.Background(), s.PublicId(), data.hostname, data.port, data.rtcpPort); err != nil { + log.Printf("Error unpublishing %s %s from remote %s: %s", publisher.StreamType(), publisher.Id(), data.hostname, err) + } + } + } + }(s.remotePublishers) + s.remotePublishers = nil +} + func (s *ProxySession) clearSubscribers() { s.publishersLock.Lock() defer s.publishersLock.Unlock() @@ -349,4 +378,58 @@ func (s *ProxySession) clearSubscribers() { func (s *ProxySession) NotifyDisconnected() { s.clearPublishers() s.clearSubscribers() + s.clearRemotePublishers() +} + +func (s *ProxySession) AddRemotePublisher(publisher signaling.McuPublisher, hostname string, port int, rtcpPort int) bool { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + remote, found := s.remotePublishers[publisher] + if !found { + remote = make(map[string]*remotePublisherData) + if s.remotePublishers == nil { + s.remotePublishers = make(map[signaling.McuPublisher]map[string]*remotePublisherData) + } + s.remotePublishers[publisher] = remote + } + + key := fmt.Sprintf("%s:%d%d", hostname, port, rtcpPort) + if _, found := remote[key]; found { + return false + } + + data := &remotePublisherData{ + hostname: hostname, + port: port, + rtcpPort: rtcpPort, + } + remote[key] = data + return true +} + +func (s *ProxySession) RemoveRemotePublisher(publisher signaling.McuPublisher, hostname string, port int, rtcpPort int) { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + remote, found := s.remotePublishers[publisher] + if !found { + return + } + + key := fmt.Sprintf("%s:%d%d", hostname, port, rtcpPort) + delete(remote, key) + if len(remote) == 0 { + delete(s.remotePublishers, publisher) + if len(s.remotePublishers) == 0 { + s.remotePublishers = nil + } + } +} + +func (s *ProxySession) OnPublisherDeleted(publisher signaling.McuPublisher) { + s.remotePublishersLock.Lock() + defer s.remotePublishersLock.Unlock() + + delete(s.remotePublishers, publisher) }