Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add async api #461

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion matchbox_socket/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ once_cell = { version = "1.17", default-features = false, features = [
"alloc",
] }
derive_more = { version = "1.0", features = ["display", "from"] }
tokio-util = { version = "0.7", features = ["io", "compat"] }

ggrs = { version = "0.10", default-features = false, optional = true }
bincode = { version = "1.3", default-features = false, optional = true }
bytes = { version = "1.1", default-features = false }

[target.'cfg(target_arch = "wasm32")'.dependencies]
ggrs = { version = "0.10", default-features = false, optional = true, features = [
Expand Down Expand Up @@ -79,7 +81,6 @@ async-tungstenite = { version = "0.28", default-features = false, features = [
"async-tls",
] }
webrtc = { version = "0.12", default-features = false }
bytes = { version = "1.1", default-features = false }
async-compat = { version = "0.2", default-features = false }

[dev-dependencies]
Expand Down
175 changes: 173 additions & 2 deletions matchbox_socket/src/webrtc_socket/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@ use crate::{
},
Error,
};
use futures::{future::Fuse, select, Future, FutureExt, StreamExt};
use bytes::Bytes;
use futures::{
future::Fuse, select, AsyncRead, AsyncWrite, Future, FutureExt, Sink, SinkExt, Stream,
StreamExt, TryStreamExt,
};
use futures_channel::mpsc::{SendError, TrySendError, UnboundedReceiver, UnboundedSender};
use log::{debug, error};
use matchbox_protocol::PeerId;
use std::{collections::HashMap, pin::Pin, time::Duration};
use std::{collections::HashMap, future::ready, pin::Pin, task::Poll, time::Duration};
use tokio_util::{
compat::TokioAsyncWriteCompatExt,
io::{CopyToBytes, SinkWriter},
};

/// Configuration options for an ICE server connection.
/// See also: <https://developer.mozilla.org/en-US/docs/Web/API/RTCIceServer#example>
Expand Down Expand Up @@ -310,6 +318,119 @@ impl WebRtcChannel {
}
}

impl Stream for WebRtcChannel {
type Item = (PeerId, Packet);

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut rx = Pin::new(&mut self.get_mut().rx);
rx.as_mut().poll_next(cx)
}
}

impl Sink<(PeerId, Packet)> for WebRtcChannel {
type Error = SendError;

fn poll_ready(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().poll_ready(cx)
}

fn start_send(self: Pin<&mut Self>, item: (PeerId, Packet)) -> Result<(), Self::Error> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().start_send(item)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
let mut tx = Pin::new(&mut self.get_mut().tx);
tx.as_mut().poll_close(cx)
}
}

/// A channel which supports reading and writing raw bytes.
pub struct RawPeerChannel<R, W> {
id: Option<PeerId>,
remote: PeerId,
reader: R,
writer: W,
}

impl<R, W> RawPeerChannel<R, W> {
/// Returns the id of this peer.
///
/// Also see [`WebRtcSocket::id`].
pub fn id(&self) -> Option<PeerId> {
self.id
}

/// Returns the id of the remote peer to which this channel is connected.
pub fn remote(&self) -> PeerId {
self.remote
}
}

impl<R, W> AsyncRead for RawPeerChannel<R, W>
where
Self: Unpin,
R: AsyncRead + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let mut reader = Pin::new(&mut self.get_mut().reader);
reader.as_mut().poll_read(cx, buf)
}
}

impl<R, W> AsyncWrite for RawPeerChannel<R, W>
where
Self: Unpin,
W: AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let mut writer = Pin::new(&mut self.get_mut().writer);
writer.as_mut().poll_write(cx, buf)
}

fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
let mut writer = Pin::new(&mut self.get_mut().writer);
writer.as_mut().poll_flush(cx)
}

fn poll_close(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
let mut writer = Pin::new(&mut self.get_mut().writer);
writer.as_mut().poll_close(cx)
}
}

/// Contains a set of [`WebRtcChannel`]s and connection metadata.
#[derive(Debug)]
pub struct WebRtcSocket {
Expand Down Expand Up @@ -566,6 +687,36 @@ impl WebRtcSocket {
.ok_or(ChannelError::Taken)
}

/// Takes the [`WebRtcChannel`] of a given [`PeerId`].
pub fn take_channel_by_id(&mut self, id: PeerId) -> Result<WebRtcChannel, ChannelError> {
let pos = self
.connected_peers()
.position(|peer_id| peer_id == id)
.ok_or(ChannelError::NotFound)?;

self.take_channel(pos)
}

/// Converts the [`WebRtcChannel`] of a given [`PeerId`] into a [`RawPeerChannel`].
pub fn take_raw_by_id(
&mut self,
remote: PeerId,
) -> Result<RawPeerChannel<impl AsyncRead, impl AsyncWrite>, ChannelError> {
let channel = self.take_channel_by_id(remote)?;
let id = self.id();

let (reader, writer) = compat_read_write(remote, channel.rx, channel.tx);

let peer_channel = RawPeerChannel {
id,
remote,
reader,
writer,
};

Ok(peer_channel)
}

/// Returns whether any socket channel is closed
pub fn any_channel_closed(&self) -> bool {
self.channels
Expand Down Expand Up @@ -685,6 +836,26 @@ async fn run_socket(
}
}

fn compat_read_write(
remote: PeerId,
stream: UnboundedReceiver<(PeerId, Packet)>,
sink: UnboundedSender<(PeerId, Packet)>,
) -> (impl AsyncRead, impl AsyncWrite) {
let reader = stream
.then(|(_, packet)| ready(Ok::<_, std::io::Error>(packet)))
.into_async_read();

let writer = sink
.with(move |packet: Bytes| ready(Ok::<_, SendError>((remote, Box::from(packet.as_ref())))));

let writer = writer.sink_map_err(std::io::Error::other);
let writer = CopyToBytes::new(writer);
let writer = SinkWriter::new(writer);
let writer = TokioAsyncWriteCompatExt::compat_write(writer);

(reader, writer)
}

#[cfg(test)]
mod test {
use crate::{ChannelConfig, Error, WebRtcSocketBuilder};
Expand Down
Loading