diff --git a/Cargo.lock b/Cargo.lock index a479eeb5..4adcda70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3306,6 +3306,7 @@ dependencies = [ "serde-wasm-bindgen", "serde_json", "thiserror 2.0.6", + "tokio-util", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", @@ -5014,6 +5015,7 @@ checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", diff --git a/matchbox_socket/Cargo.toml b/matchbox_socket/Cargo.toml index f421939a..c2f5c287 100644 --- a/matchbox_socket/Cargo.toml +++ b/matchbox_socket/Cargo.toml @@ -40,9 +40,11 @@ once_cell = { version = "1.17", default-features = false, features = [ "alloc", ] } derive_more = { version = "1.0", features = ["display", "from"] } +tokio-util = { version = "0.7", features = ["io", "compat"] } ggrs = { version = "0.10", default-features = false, optional = true } bincode = { version = "1.3", default-features = false, optional = true } +bytes = { version = "1.1", default-features = false } [target.'cfg(target_arch = "wasm32")'.dependencies] ggrs = { version = "0.10", default-features = false, optional = true, features = [ @@ -79,7 +81,6 @@ async-tungstenite = { version = "0.28", default-features = false, features = [ "async-tls", ] } webrtc = { version = "0.12", default-features = false } -bytes = { version = "1.1", default-features = false } async-compat = { version = "0.2", default-features = false } [dev-dependencies] diff --git a/matchbox_socket/src/webrtc_socket/socket.rs b/matchbox_socket/src/webrtc_socket/socket.rs index c125cd2d..ccb5ab8c 100644 --- a/matchbox_socket/src/webrtc_socket/socket.rs +++ b/matchbox_socket/src/webrtc_socket/socket.rs @@ -6,11 +6,19 @@ use crate::{ }, Error, }; -use futures::{future::Fuse, select, Future, FutureExt, StreamExt}; +use bytes::Bytes; +use futures::{ + future::Fuse, select, AsyncRead, AsyncWrite, Future, FutureExt, Sink, SinkExt, Stream, + StreamExt, TryStreamExt, +}; use futures_channel::mpsc::{SendError, TrySendError, UnboundedReceiver, UnboundedSender}; use log::{debug, error}; use matchbox_protocol::PeerId; -use std::{collections::HashMap, pin::Pin, time::Duration}; +use std::{collections::HashMap, future::ready, pin::Pin, task::Poll, time::Duration}; +use tokio_util::{ + compat::TokioAsyncWriteCompatExt, + io::{CopyToBytes, SinkWriter}, +}; /// Configuration options for an ICE server connection. /// See also: @@ -310,6 +318,119 @@ impl WebRtcChannel { } } +impl Stream for WebRtcChannel { + type Item = (PeerId, Packet); + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut rx = Pin::new(&mut self.get_mut().rx); + rx.as_mut().poll_next(cx) + } +} + +impl Sink<(PeerId, Packet)> for WebRtcChannel { + type Error = SendError; + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: (PeerId, Packet)) -> Result<(), Self::Error> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().start_send(item) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut tx = Pin::new(&mut self.get_mut().tx); + tx.as_mut().poll_close(cx) + } +} + +/// A channel which supports reading and writing raw bytes. +pub struct RawPeerChannel { + id: Option, + remote: PeerId, + reader: R, + writer: W, +} + +impl RawPeerChannel { + /// Returns the id of this peer. + /// + /// Also see [`WebRtcSocket::id`]. + pub fn id(&self) -> Option { + self.id + } + + /// Returns the id of the remote peer to which this channel is connected. + pub fn remote(&self) -> PeerId { + self.remote + } +} + +impl AsyncRead for RawPeerChannel +where + Self: Unpin, + R: AsyncRead + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut reader = Pin::new(&mut self.get_mut().reader); + reader.as_mut().poll_read(cx, buf) + } +} + +impl AsyncWrite for RawPeerChannel +where + Self: Unpin, + W: AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut writer = Pin::new(&mut self.get_mut().writer); + writer.as_mut().poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut writer = Pin::new(&mut self.get_mut().writer); + writer.as_mut().poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut writer = Pin::new(&mut self.get_mut().writer); + writer.as_mut().poll_close(cx) + } +} + /// Contains a set of [`WebRtcChannel`]s and connection metadata. #[derive(Debug)] pub struct WebRtcSocket { @@ -566,6 +687,36 @@ impl WebRtcSocket { .ok_or(ChannelError::Taken) } + /// Takes the [`WebRtcChannel`] of a given [`PeerId`]. + pub fn take_channel_by_id(&mut self, id: PeerId) -> Result { + let pos = self + .connected_peers() + .position(|peer_id| peer_id == id) + .ok_or(ChannelError::NotFound)?; + + self.take_channel(pos) + } + + /// Converts the [`WebRtcChannel`] of a given [`PeerId`] into a [`RawPeerChannel`]. + pub fn take_raw_by_id( + &mut self, + remote: PeerId, + ) -> Result, ChannelError> { + let channel = self.take_channel_by_id(remote)?; + let id = self.id(); + + let (reader, writer) = compat_read_write(remote, channel.rx, channel.tx); + + let peer_channel = RawPeerChannel { + id, + remote, + reader, + writer, + }; + + Ok(peer_channel) + } + /// Returns whether any socket channel is closed pub fn any_channel_closed(&self) -> bool { self.channels @@ -685,6 +836,26 @@ async fn run_socket( } } +fn compat_read_write( + remote: PeerId, + stream: UnboundedReceiver<(PeerId, Packet)>, + sink: UnboundedSender<(PeerId, Packet)>, +) -> (impl AsyncRead, impl AsyncWrite) { + let reader = stream + .then(|(_, packet)| ready(Ok::<_, std::io::Error>(packet))) + .into_async_read(); + + let writer = sink + .with(move |packet: Bytes| ready(Ok::<_, SendError>((remote, Box::from(packet.as_ref()))))); + + let writer = writer.sink_map_err(std::io::Error::other); + let writer = CopyToBytes::new(writer); + let writer = SinkWriter::new(writer); + let writer = TokioAsyncWriteCompatExt::compat_write(writer); + + (reader, writer) +} + #[cfg(test)] mod test { use crate::{ChannelConfig, Error, WebRtcSocketBuilder};