diff --git a/main.go b/main.go index b033181..c98288c 100644 --- a/main.go +++ b/main.go @@ -4,32 +4,29 @@ package main import ( "encoding/json" + "flag" "fmt" "log" "net" "net/http" "strconv" "strings" - "sync" - - "flag" + "github.com/GRVYDEV/lightspeed-webrtc/ws" "github.com/gorilla/websocket" "github.com/pion/interceptor" "github.com/pion/rtp" "github.com/pion/webrtc/v3" - "github.com/pion/webrtc/v3/pkg/media/samplebuilder" ) var ( - videoBuilder *samplebuilder.SampleBuilder - addr = flag.String("addr", "localhost", "http service address") - ip = flag.String("ip", "none", "IP address for webrtc") - wsPort = flag.Int("ws-port", 8080, "Port for websocket") - rtpPort = flag.Int("rtp-port", 65535, "Port for RTP") - ports = flag.String("ports", "20000-20500", "Port range for webrtc") - upgrader = websocket.Upgrader{ + addr = flag.String("addr", "localhost", "http service address") + ip = flag.String("ip", "none", "IP address for webrtc") + wsPort = flag.Int("ws-port", 8080, "Port for websocket") + rtpPort = flag.Int("rtp-port", 65535, "Port for RTP") + ports = flag.String("ports", "20000-20500", "Port range for webrtc") + upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } @@ -37,26 +34,12 @@ var ( audioTrack *webrtc.TrackLocalStaticRTP - // lock for peerConnections and trackLocals - listLock sync.RWMutex - peerConnections []peerConnectionState - trackLocals map[string]*webrtc.TrackLocalStaticRTP + hub *ws.Hub ) -type websocketMessage struct { - Event string `json:"event"` - Data string `json:"data"` -} - -type peerConnectionState struct { - peerConnection *webrtc.PeerConnection - websocket *threadSafeWriter -} - func main() { flag.Parse() log.SetFlags(0) - trackLocals = map[string]*webrtc.TrackLocalStaticRTP{} // Open a UDP Listener for RTP Packets on port 65535 listener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP(*addr), Port: *rtpPort}) @@ -83,6 +66,9 @@ func main() { panic(err) } + hub = ws.NewHub() + go hub.Run() + // start HTTP server go func() { http.HandleFunc("/websocket", websocketHandler) @@ -162,34 +148,18 @@ func createWebrtcApi() *webrtc.API { return webrtc.NewAPI(webrtc.WithMediaEngine(m), webrtc.WithInterceptorRegistry(i), webrtc.WithSettingEngine(s)) } -func cleanConnection(peerConnection *webrtc.PeerConnection) { - listLock.Lock() - defer listLock.Unlock() - - for i := range peerConnections { - if peerConnection == peerConnections[i].peerConnection { - peerConnections[i] = peerConnections[len(peerConnections)-1] - peerConnections[len(peerConnections)-1] = peerConnectionState{} - peerConnections = peerConnections[:len(peerConnections)-1] - return - } - } -} - // Handle incoming websockets func websocketHandler(w http.ResponseWriter, r *http.Request) { // Upgrade HTTP request to Websocket - unsafeConn, err := upgrader.Upgrade(w, r, nil) + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Print("upgrade:", err) return } - c := &threadSafeWriter{unsafeConn, sync.Mutex{}} - // When this frame returns close the Websocket - defer c.Close() //nolint + defer conn.Close() //nolint // Create API that takes IP and port range into account api := createWebrtcApi() @@ -231,17 +201,12 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { } }() - // Add our new PeerConnection to global list - listLock.Lock() - peerConnections = append(peerConnections, peerConnectionState{peerConnection, c}) - noConnections := len(peerConnections) - for _, conn := range peerConnections { - if msg, err := json.Marshal(noConnections); err == nil { - conn.websocket.WriteJSON(&websocketMessage{Event: "connections", Data: string(msg)}) - } - } - fmt.Printf("Connections: %d\n", len(peerConnections)) - listLock.Unlock() + c := ws.NewClient(hub, conn, peerConnection) + + go c.WriteLoop() + + // Add to the hub + hub.Register <- c // Trickle ICE. Emit server candidate to client peerConnection.OnICECandidate(func(i *webrtc.ICECandidate) { @@ -255,11 +220,13 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { return } - if writeErr := c.WriteJSON(&websocketMessage{ - Event: "candidate", + if msg, err := json.Marshal(ws.WebsocketMessage{ + Event: ws.MessageTypeCandidate, Data: string(candidateString), - }); writeErr != nil { - log.Println(writeErr) + }); err == nil { + c.Send <- msg + } else { + log.Println(err) } }) @@ -270,9 +237,10 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { if err := peerConnection.Close(); err != nil { log.Print(err) } - case webrtc.PeerConnectionStateClosed: - cleanConnection(peerConnection) + hub.Unregister <- c + case webrtc.PeerConnectionStateClosed: + hub.Unregister <- c } }) @@ -290,61 +258,16 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) { log.Print(err) } - if err = c.WriteJSON(&websocketMessage{ - Event: "offer", + if msg, err := json.Marshal(ws.WebsocketMessage{ + Event: ws.MessageTypeOffer, Data: string(offerString), - }); err != nil { - log.Print(err) + }); err == nil { + c.Send <- msg + } else { + log.Printf("could not marshal ws message: %s", err) } - message := &websocketMessage{} - for { - _, raw, err := c.ReadMessage() - if err != nil { - log.Println(err) - return - } else if err := json.Unmarshal(raw, &message); err != nil { - log.Println(err) - return - } - - switch message.Event { - case "candidate": - - candidate := webrtc.ICECandidateInit{} - if err := json.Unmarshal([]byte(message.Data), &candidate); err != nil { - log.Println(err) - return - } - - if err := peerConnection.AddICECandidate(candidate); err != nil { - log.Println(err) - return - } - case "answer": - - answer := webrtc.SessionDescription{} - if err := json.Unmarshal([]byte(message.Data), &answer); err != nil { - log.Println(err) - return - } - - if err := peerConnection.SetRemoteDescription(answer); err != nil { - log.Println(err) - return - } - } - } -} - -type threadSafeWriter struct { - *websocket.Conn - sync.Mutex -} - -func (t *threadSafeWriter) WriteJSON(v interface{}) error { - t.Lock() - defer t.Unlock() + go hub.SendInfo(hub.GetInfo()) // non-blocking broadcast, required as the read loop is not started yet. - return t.Conn.WriteJSON(v) + c.ReadLoop() } diff --git a/ws/client.go b/ws/client.go new file mode 100644 index 0000000..daa86a9 --- /dev/null +++ b/ws/client.go @@ -0,0 +1,144 @@ +package ws + +import ( + "encoding/json" + "log" + "time" + + "github.com/gorilla/websocket" + "github.com/pion/webrtc/v3" +) + +// Client is a middleman between the websocket connection and the hub. +type Client struct { + hub *Hub + + // The websocket connection. + conn *websocket.Conn + + // Buffered channel of outbound messages. + Send chan []byte + + // webRTC peer connection + PeerConnection *webrtc.PeerConnection +} + +func NewClient(hub *Hub, conn *websocket.Conn, webrtcConn *webrtc.PeerConnection) *Client { + return &Client{ + hub: hub, + conn: conn, + Send: make(chan []byte), + PeerConnection: webrtcConn, + } +} + +// ReadLoop pumps messages from the websocket connection to the hub. +// +// The application runs ReadLoop in a per-connection goroutine. The application +// ensures that there is at most one reader on a connection by executing all +// reads from this goroutine. +func (c *Client) ReadLoop() { + defer func() { + c.hub.Unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + message := &WebsocketMessage{} + for { + // _, message, err := c.conn.ReadMessage() + _, raw, err := c.conn.ReadMessage() + if err != nil { + log.Printf("could not read message: %s", err) + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Println("ws closed unexpected") + } + return + } + + err = json.Unmarshal(raw, &message) + if err != nil { + log.Printf("could not unmarshal ws message: %s", err) + return + } + + switch message.Event { + case MessageTypeCandidate: + candidate := webrtc.ICECandidateInit{} + if err := json.Unmarshal([]byte(message.Data), &candidate); err != nil { + log.Printf("could not unmarshal candidate msg: %s", err) + return + } + + if err := c.PeerConnection.AddICECandidate(candidate); err != nil { + log.Printf("could not add ice candidate: %s", err) + return + } + + case MessageTypeAnswer: + answer := webrtc.SessionDescription{} + if err := json.Unmarshal([]byte(message.Data), &answer); err != nil { + log.Printf("could not unmarshal answer msg: %s", err) + return + } + + if err := c.PeerConnection.SetRemoteDescription(answer); err != nil { + log.Printf("could not set remote description: %s", err) + return + } + } + + // we do not send anything to the other clients! + //message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + //c.hub.Broadcast <- message + } +} + +// WriteLoop pumps messages from the hub to the websocket connection. +// +// A goroutine running WriteLoop is started for each connection. The +// application ensures that there is at most one writer to a connection by +// executing all writes from this goroutine. +func (c *Client) WriteLoop() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + for { + select { + case message, ok := <-c.Send: + _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // The hub closed the channel. + _ = c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + _, _ = w.Write(message) + + // Add queued messages to the current websocket message. + n := len(c.Send) + for i := 0; i < n; i++ { + _, _ = w.Write([]byte{'\n'}) + message = <-c.Send + _, _ = w.Write(message) + } + + if err := w.Close(); err != nil { + return + } + + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} diff --git a/ws/hub.go b/ws/hub.go new file mode 100644 index 0000000..a887580 --- /dev/null +++ b/ws/hub.go @@ -0,0 +1,88 @@ +package ws + +import ( + "encoding/json" + "log" + "time" +) + +const ( + maxMessageSize = 4096 + pongWait = 2 * time.Minute + pingPeriod = time.Minute + writeWait = 10 * time.Second +) + +type Info struct { + NoConnections int `json:"no_connections"` +} + +type Hub struct { + // Registered clients. + clients map[*Client]struct{} + + // Broadcast messages to all clients. + Broadcast chan []byte + + // Register a new client to the hub. + Register chan *Client + + // Unregister a client from the hub. + Unregister chan *Client +} + +func NewHub() *Hub { + return &Hub{ + clients: make(map[*Client]struct{}), + Broadcast: make(chan []byte), + Register: make(chan *Client), + Unregister: make(chan *Client), + } +} + +// NoClients returns the number of clients registered +func (h *Hub) NoClients() int { + return len(h.clients) +} + +// Run is the main hub event loop handling register, unregister and broadcast events. +func (h *Hub) Run() { + for { + select { + case client := <-h.Register: + h.clients[client] = struct{}{} + case client := <-h.Unregister: + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.Send) + go h.SendInfo(h.GetInfo()) // this way the number of clients does not change between calling the goroutine and executing it + } + case message := <-h.Broadcast: + for client := range h.clients { + client.Send <- message + } + } + } +} + +func (h *Hub) GetInfo() Info { + return Info{ + NoConnections: h.NoClients(), + } +} + +// SendInfo broadcasts hub statistics to all clients. +func (h *Hub) SendInfo(info Info) { + i, err := json.Marshal(info) + if err != nil { + log.Printf("could not marshal ws info: %s", err) + } + if msg, err := json.Marshal(WebsocketMessage{ + Event: MessageTypeInfo, + Data: string(i), + }); err == nil { + h.Broadcast <- msg + } else { + log.Printf("could not marshal ws message: %s", err) + } +} \ No newline at end of file diff --git a/ws/message.go b/ws/message.go new file mode 100644 index 0000000..6c2d721 --- /dev/null +++ b/ws/message.go @@ -0,0 +1,13 @@ +package ws + +const ( + MessageTypeAnswer = "answer" + MessageTypeCandidate = "candidate" + MessageTypeOffer = "offer" + MessageTypeInfo = "info" +) + +type WebsocketMessage struct { + Event string `json:"event"` + Data string `json:"data"` +}