From 414f9730610fbe6c3da34591b4c452bd5e1d172d Mon Sep 17 00:00:00 2001 From: th4s Date: Mon, 16 Dec 2024 19:17:49 +0100 Subject: [PATCH 1/3] feat: add async channel api --- Cargo.lock | 2 + matchbox_socket/Cargo.toml | 3 +- matchbox_socket/src/webrtc_socket/socket.rs | 169 +++++++++++++++++++- 3 files changed, 171 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a479eeb5..4adcda70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3306,6 +3306,7 @@ dependencies = [ "serde-wasm-bindgen", "serde_json", "thiserror 2.0.6", + "tokio-util", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", @@ -5014,6 +5015,7 @@ checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", diff --git a/matchbox_socket/Cargo.toml b/matchbox_socket/Cargo.toml index f421939a..e598b7f0 100644 --- a/matchbox_socket/Cargo.toml +++ b/matchbox_socket/Cargo.toml @@ -40,9 +40,11 @@ once_cell = { version = "1.17", default-features = false, features = [ "alloc", ] } derive_more = { version = "1.0", features = ["display", "from"] } +tokio-util = { version = "0.7", features = ["io", "compat"]} ggrs = { version = "0.10", default-features = false, optional = true } bincode = { version = "1.3", default-features = false, optional = true } +bytes = { version = "1.1", default-features = false } [target.'cfg(target_arch = "wasm32")'.dependencies] ggrs = { version = "0.10", default-features = false, optional = true, features = [ @@ -79,7 +81,6 @@ async-tungstenite = { version = "0.28", default-features = false, features = [ "async-tls", ] } webrtc = { version = "0.12", default-features = false } -bytes = { version = "1.1", default-features = false } async-compat = { version = "0.2", default-features = false } [dev-dependencies] diff --git a/matchbox_socket/src/webrtc_socket/socket.rs b/matchbox_socket/src/webrtc_socket/socket.rs index c125cd2d..09db2b70 100644 --- a/matchbox_socket/src/webrtc_socket/socket.rs +++ b/matchbox_socket/src/webrtc_socket/socket.rs @@ -6,11 +6,19 @@ use crate::{ }, Error, }; -use futures::{future::Fuse, select, Future, FutureExt, StreamExt}; +use bytes::Bytes; +use futures::{ + future::Fuse, select, AsyncRead, AsyncWrite, Future, FutureExt, Sink, SinkExt, Stream, + StreamExt, TryStreamExt, +}; use futures_channel::mpsc::{SendError, TrySendError, UnboundedReceiver, UnboundedSender}; use log::{debug, error}; use matchbox_protocol::PeerId; -use std::{collections::HashMap, pin::Pin, time::Duration}; +use std::{collections::HashMap, future::ready, pin::Pin, task::Poll, time::Duration}; +use tokio_util::{ + compat::TokioAsyncWriteCompatExt, + io::{CopyToBytes, SinkWriter}, +}; /// Configuration options for an ICE server connection. /// See also: @@ -310,6 +318,119 @@ impl WebRtcChannel { } } +impl Stream for WebRtcChannel { + type Item = (PeerId, Packet); + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut rx = Pin::new(&mut self.get_mut().rx); + rx.as_mut().poll_next(cx) + } +} + +impl Sink<(PeerId, Packet)> for WebRtcChannel { + type Error = SendError; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: (PeerId, Packet)) -> Result<(), Self::Error> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().start_send(item) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().poll_close(cx) + } +} + +/// A channel which supports reading and writing raw bytes. +pub struct RawPeerChannel { + id: Option, + remote: PeerId, + reader: R, + writer: W, +} + +impl RawPeerChannel { + /// Returns the id of this peer. + /// + /// Also see [`WebRtcSocket::id`]. + pub fn id(&self) -> Option { + self.id + } + + /// Returns the id of the remote peer to which this channel is connected. + pub fn remote(&self) -> PeerId { + self.remote + } +} + +impl AsyncRead for RawPeerChannel +where + Self: Unpin, + R: AsyncRead + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut reader = Pin::new(&mut self.get_mut().reader); + reader.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for RawPeerChannel +where + Self: Unpin, + W: AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut writer = Pin::new(&mut self.get_mut().writer); + writer.as_mut().poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut writer = Pin::new(&mut self.get_mut().writer); + writer.as_mut().poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut writer = Pin::new(&mut self.get_mut().writer); + writer.as_mut().poll_close(cx) + } +} + /// Contains a set of [`WebRtcChannel`]s and connection metadata. #[derive(Debug)] pub struct WebRtcSocket { @@ -566,6 +687,30 @@ impl WebRtcSocket { .ok_or(ChannelError::Taken) } + /// Converts the [`WebRtcChannel`] into a [`RawPeerChannel`]. + pub fn take_raw( + &mut self, + ) -> Result, ChannelError> { + let remote = self + .connected_peers() + .next() + .ok_or(ChannelError::NotFound)?; + + let channel = self.take_channel(0)?; + let id = self.id(); + + let (reader, writer) = compat_read_write(remote, channel.rx, channel.tx); + + let peer_channel = RawPeerChannel { + id, + remote, + reader, + writer, + }; + + Ok(peer_channel) + } + /// Returns whether any socket channel is closed pub fn any_channel_closed(&self) -> bool { self.channels @@ -685,6 +830,26 @@ async fn run_socket( } } +fn compat_read_write( + remote: PeerId, + stream: UnboundedReceiver<(PeerId, Packet)>, + sink: UnboundedSender<(PeerId, Packet)>, +) -> (impl AsyncRead, impl AsyncWrite) { + let reader = stream + .then(|(_, packet)| ready(Ok::<_, std::io::Error>(packet))) + .into_async_read(); + + let writer = sink + .with(move |packet: Bytes| ready(Ok::<_, SendError>((remote, Box::from(packet.as_ref()))))); + + let writer = writer.sink_map_err(std::io::Error::other); + let writer = CopyToBytes::new(writer); + let writer = SinkWriter::new(writer); + let writer = TokioAsyncWriteCompatExt::compat_write(writer); + + (reader, writer) +} + #[cfg(test)] mod test { use crate::{ChannelConfig, Error, WebRtcSocketBuilder}; From 62c783c26ed46bd2cc40bcff78b706fa5bbd2fa0 Mon Sep 17 00:00:00 2001 From: th4s Date: Mon, 16 Dec 2024 19:31:24 +0100 Subject: [PATCH 2/3] fix: fix bug for taking channel --- matchbox_socket/src/webrtc_socket/socket.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/matchbox_socket/src/webrtc_socket/socket.rs b/matchbox_socket/src/webrtc_socket/socket.rs index 09db2b70..ccb5ab8c 100644 --- a/matchbox_socket/src/webrtc_socket/socket.rs +++ b/matchbox_socket/src/webrtc_socket/socket.rs @@ -687,16 +687,22 @@ impl WebRtcSocket { .ok_or(ChannelError::Taken) } - /// Converts the [`WebRtcChannel`] into a [`RawPeerChannel`]. - pub fn take_raw( - &mut self, - ) -> Result, ChannelError> { - let remote = self + /// Takes the [`WebRtcChannel`] of a given [`PeerId`]. + pub fn take_channel_by_id(&mut self, id: PeerId) -> Result { + let pos = self .connected_peers() - .next() + .position(|peer_id| peer_id == id) .ok_or(ChannelError::NotFound)?; - let channel = self.take_channel(0)?; + self.take_channel(pos) + } + + /// Converts the [`WebRtcChannel`] of a given [`PeerId`] into a [`RawPeerChannel`]. + pub fn take_raw_by_id( + &mut self, + remote: PeerId, + ) -> Result, ChannelError> { + let channel = self.take_channel_by_id(remote)?; let id = self.id(); let (reader, writer) = compat_read_write(remote, channel.rx, channel.tx); From c1137ef93c729ab553e624c7b81019149e94fa79 Mon Sep 17 00:00:00 2001 From: th4s Date: Mon, 16 Dec 2024 20:16:45 +0100 Subject: [PATCH 3/3] fix(lint): make taplo happy --- matchbox_socket/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matchbox_socket/Cargo.toml b/matchbox_socket/Cargo.toml index e598b7f0..c2f5c287 100644 --- a/matchbox_socket/Cargo.toml +++ b/matchbox_socket/Cargo.toml @@ -40,7 +40,7 @@ once_cell = { version = "1.17", default-features = false, features = [ "alloc", ] } derive_more = { version = "1.0", features = ["display", "from"] } -tokio-util = { version = "0.7", features = ["io", "compat"]} +tokio-util = { version = "0.7", features = ["io", "compat"] } ggrs = { version = "0.10", default-features = false, optional = true } bincode = { version = "1.3", default-features = false, optional = true }