Skip to content

Commit

Permalink
Merge pull request #461 from th4s/async-channels
Browse files Browse the repository at this point in the history
feat: add async api
  • Loading branch information
johanhelsing authored Dec 17, 2024
2 parents 98339b6 + c1137ef commit 750c9d9
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion matchbox_socket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ once_cell = { version = "1.17", default-features = false, features = [
"alloc",
] }
derive_more = { version = "1.0", features = ["display", "from"] }
tokio-util = { version = "0.7", features = ["io", "compat"] }

ggrs = { version = "0.10", default-features = false, optional = true }
bincode = { version = "1.3", default-features = false, optional = true }
bytes = { version = "1.1", default-features = false }

[target.'cfg(target_arch = "wasm32")'.dependencies]
ggrs = { version = "0.10", default-features = false, optional = true, features = [
Expand Down Expand Up @@ -79,7 +81,6 @@ async-tungstenite = { version = "0.28", default-features = false, features = [
"async-tls",
] }
webrtc = { version = "0.12", default-features = false }
bytes = { version = "1.1", default-features = false }
async-compat = { version = "0.2", default-features = false }

[dev-dependencies]
Expand Down
175 changes: 173 additions & 2 deletions matchbox_socket/src/webrtc_socket/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@ use crate::{
},
Error,
};
use futures::{future::Fuse, select, Future, FutureExt, StreamExt};
use bytes::Bytes;
use futures::{
future::Fuse, select, AsyncRead, AsyncWrite, Future, FutureExt, Sink, SinkExt, Stream,
StreamExt, TryStreamExt,
};
use futures_channel::mpsc::{SendError, TrySendError, UnboundedReceiver, UnboundedSender};
use log::{debug, error};
use matchbox_protocol::PeerId;
use std::{collections::HashMap, pin::Pin, time::Duration};
use std::{collections::HashMap, future::ready, pin::Pin, task::Poll, time::Duration};
use tokio_util::{
compat::TokioAsyncWriteCompatExt,
io::{CopyToBytes, SinkWriter},
};

/// Configuration options for an ICE server connection.
/// See also: <https://developer.mozilla.org/en-US/docs/Web/API/RTCIceServer#example>
Expand Down Expand Up @@ -310,6 +318,119 @@ impl WebRtcChannel {
}
}

impl Stream for WebRtcChannel {
type Item = (PeerId, Packet);

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut rx = Pin::new(&mut self.get_mut().rx);
rx.as_mut().poll_next(cx)
}
}

impl Sink<(PeerId, Packet)> for WebRtcChannel {
type Error = SendError;

fn poll_ready(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().poll_ready(cx)
}

fn start_send(self: Pin<&mut Self>, item: (PeerId, Packet)) -> Result<(), Self::Error> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().start_send(item)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().poll_close(cx)
}
}

/// A channel which supports reading and writing raw bytes.
pub struct RawPeerChannel<R, W> {
id: Option<PeerId>,
remote: PeerId,
reader: R,
writer: W,
}

impl<R, W> RawPeerChannel<R, W> {
/// Returns the id of this peer.
///
/// Also see [`WebRtcSocket::id`].
pub fn id(&self) -> Option<PeerId> {
self.id
}

/// Returns the id of the remote peer to which this channel is connected.
pub fn remote(&self) -> PeerId {
self.remote
}
}

impl<R, W> AsyncRead for RawPeerChannel<R, W>
where
Self: Unpin,
R: AsyncRead + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let mut reader = Pin::new(&mut self.get_mut().reader);
reader.as_mut().poll_read(cx, buf)
}
}

impl<R, W> AsyncWrite for RawPeerChannel<R, W>
where
Self: Unpin,
W: AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let mut writer = Pin::new(&mut self.get_mut().writer);
writer.as_mut().poll_write(cx, buf)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
let mut writer = Pin::new(&mut self.get_mut().writer);
writer.as_mut().poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
let mut writer = Pin::new(&mut self.get_mut().writer);
writer.as_mut().poll_close(cx)
}
}

/// Contains a set of [`WebRtcChannel`]s and connection metadata.
#[derive(Debug)]
pub struct WebRtcSocket {
Expand Down Expand Up @@ -566,6 +687,36 @@ impl WebRtcSocket {
.ok_or(ChannelError::Taken)
}

/// Takes the [`WebRtcChannel`] of a given [`PeerId`].
pub fn take_channel_by_id(&mut self, id: PeerId) -> Result<WebRtcChannel, ChannelError> {
let pos = self
.connected_peers()
.position(|peer_id| peer_id == id)
.ok_or(ChannelError::NotFound)?;

self.take_channel(pos)
}

/// Converts the [`WebRtcChannel`] of a given [`PeerId`] into a [`RawPeerChannel`].
pub fn take_raw_by_id(
&mut self,
remote: PeerId,
) -> Result<RawPeerChannel<impl AsyncRead, impl AsyncWrite>, ChannelError> {
let channel = self.take_channel_by_id(remote)?;
let id = self.id();

let (reader, writer) = compat_read_write(remote, channel.rx, channel.tx);

let peer_channel = RawPeerChannel {
id,
remote,
reader,
writer,
};

Ok(peer_channel)
}

/// Returns whether any socket channel is closed
pub fn any_channel_closed(&self) -> bool {
self.channels
Expand Down Expand Up @@ -685,6 +836,26 @@ async fn run_socket(
}
}

fn compat_read_write(
remote: PeerId,
stream: UnboundedReceiver<(PeerId, Packet)>,
sink: UnboundedSender<(PeerId, Packet)>,
) -> (impl AsyncRead, impl AsyncWrite) {
let reader = stream
.then(|(_, packet)| ready(Ok::<_, std::io::Error>(packet)))
.into_async_read();

let writer = sink
.with(move |packet: Bytes| ready(Ok::<_, SendError>((remote, Box::from(packet.as_ref())))));

let writer = writer.sink_map_err(std::io::Error::other);
let writer = CopyToBytes::new(writer);
let writer = SinkWriter::new(writer);
let writer = TokioAsyncWriteCompatExt::compat_write(writer);

(reader, writer)
}

#[cfg(test)]
mod test {
use crate::{ChannelConfig, Error, WebRtcSocketBuilder};
Expand Down

0 comments on commit 750c9d9

Please sign in to comment.