diff --git a/Cargo.toml b/Cargo.toml index 66330f2..df56c46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,9 +40,10 @@ tokio-rustls = { version = "0.24.1", optional = true } tokio-native-tls = { version = "0.3.1", optional = true } [features] -default = ["client", "server"] +default = ["native_client", "server"] -client = ["tokio-tungstenite"] +client = ["tokio-tungstenite-wasm"] +native_client = ["client", "tokio-tungstenite"] tungstenite_common = ["tokio-tungstenite"] diff --git a/src/client.rs b/src/client.rs index 769ddbc..0756dc1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,13 +52,15 @@ use crate::socket::{InMessage, MessageSignal, SocketConfig}; use crate::CloseFrame; use crate::Error; use crate::Message; +use crate::RawMessage; use crate::Request; use crate::Socket; use async_trait::async_trait; use base64::Engine; -use enfync::TryAdopt; +use enfync::Handle; use http::header::HeaderName; use http::HeaderValue; +use futures::{SinkExt, StreamExt}; use std::fmt; use std::future::Future; use std::time::Duration; @@ -250,6 +252,34 @@ pub trait ClientExt: Send { } } +/// Abstract interface used by clients to connect to servers. +/// +/// The connector must expose a handle representing the runtime that the client will run on. The runtime should +/// be compatible with the connection method (e.g. `tokio` for `tokio_tungstenite::connect()`, +/// `wasm_bindgen_futures::spawn_local()` for a WASM connector, etc.). +#[async_trait] +pub trait ClientConnector { + type Handle: enfync::Handle; + type Message: Into + From + std::fmt::Debug + Send + 'static; + type WSError: std::error::Error + Into; + type Socket: + SinkExt + + StreamExt> + + Unpin + + Send + + 'static; + type ConnectError: std::error::Error + Send; + + /// Get the connector's runtime handle. + fn handle(&self) -> Self::Handle; + + /// Connect to a websocket server. + /// + /// Returns `Err` if the request is invalid. + async fn connect(&self, request: Request) -> Result; +} + +/// An `ezsockets` client. #[derive(Debug)] pub struct Client { to_socket_sender: mpsc::UnboundedSender, @@ -345,24 +375,24 @@ impl Client { } /// Connect to a websocket server using the default client connector. -/// - Requires a tokio runtime. +/// - Requires feature `native_client`. +/// - May only be invoked from within a tokio runtime. +#[cfg(feature = "native_client")] pub async fn connect( client_fn: impl FnOnce(Client) -> E, config: ClientConfig, ) -> (Client, impl Future>) { - let client_connector = enfync::builtin::native::TokioHandle::try_adopt() - .expect("ezsockets::client::connect() only works with tokio runtimes; use connect_with() instead"); + let client_connector = crate::client_connector_tokio::ClientConnectorTokio::default(); let (handle, mut future) = connect_with(client_fn, config, client_connector); let future = async move { future.extract().await.unwrap_or(Err("client actor crashed".into())) }; (handle, future) } /// Connect to a websocket server with the provided client connector. -/// - TODO: add ClientConnector trait (currently uses default client connector: tokio-tungstenite in tokio runtime) pub fn connect_with( client_fn: impl FnOnce(Client) -> E, config: ClientConfig, - client_connector: impl enfync::Handle + Send + Sync + 'static, + client_connector: impl ClientConnector + Send + Sync + 'static, ) -> (Client, enfync::PendingResult>) { let (to_socket_sender, mut to_socket_receiver) = mpsc::unbounded_channel(); let (client_call_sender, client_call_receiver) = mpsc::unbounded_channel(); @@ -371,13 +401,13 @@ pub fn connect_with( client_call_sender, }; let mut client = client_fn(handle.clone()); - let client_connector_clone = client_connector.clone(); - let future = client_connector.spawn(async move { + let runtime_handle = client_connector.handle(); + let future = runtime_handle.spawn(async move { tracing::info!("connecting to {}...", config.url); let Some(socket) = client_connect( config.max_initial_connect_attempts, &config, - client_connector_clone.clone(), + &client_connector, &mut to_socket_receiver, &mut client, ) @@ -393,7 +423,7 @@ pub fn connect_with( client_call_receiver, socket, config, - client_connector: client_connector_clone, + client_connector, }; actor.run().await?; Ok(()) @@ -401,7 +431,7 @@ pub fn connect_with( (handle, future) } -struct ClientActor { +struct ClientActor { client: E, to_socket_receiver: mpsc::UnboundedReceiver, client_call_receiver: mpsc::UnboundedReceiver, @@ -410,7 +440,7 @@ struct ClientActor { client_connector: C, } -impl ClientActor { +impl ClientActor { async fn run(&mut self) -> Result<(), Error> { loop { tokio::select! { @@ -460,7 +490,7 @@ impl ClientActor { let Some(socket) = client_connect( self.config.max_reconnect_attempts, &self.config, - self.client_connector.clone(), + &self.client_connector, &mut self.to_socket_receiver, &mut self.client, ).await? else { @@ -485,7 +515,7 @@ impl ClientActor { let Some(socket) = client_connect( self.config.max_reconnect_attempts, &self.config, - self.client_connector.clone(), + &self.client_connector, &mut self.to_socket_receiver, &mut self.client, ).await? else { @@ -507,10 +537,10 @@ impl ClientActor { } /// Returns Ok(Some(socket)) if connecting succeeded, Ok(None) if the client closed itself, and `Err` if an error occurred. -async fn client_connect( +async fn client_connect( max_attempts: usize, config: &ClientConfig, - client_connector: impl enfync::Handle, + client_connector: &Connector, to_socket_receiver: &mut mpsc::UnboundedReceiver, client: &mut E, ) -> Result, Error> { @@ -518,9 +548,9 @@ async fn client_connect( // connection attempt tracing::info!("connecting attempt no: {}...", i); let connect_http_request = config.connect_http_request(); - let result = tokio_tungstenite::connect_async(connect_http_request).await; //todo: ClientConnector::connect() + let result = client_connector.connect(connect_http_request).await; match result { - Ok((socket_impl, _)) => { + Ok(socket_impl) => { tracing::info!("successfully connected"); if let Err(err) = client.on_connect().await { tracing::error!("calling on_connect() failed due to {}, closing client", err); @@ -529,7 +559,7 @@ async fn client_connect( let socket = Socket::new( socket_impl, config.socket_config.clone().unwrap_or_default(), - client_connector, + client_connector.handle(), ); return Ok(Some(socket)); } diff --git a/src/client_connectors/client_connector_tokio.rs b/src/client_connectors/client_connector_tokio.rs new file mode 100644 index 0000000..9e7de44 --- /dev/null +++ b/src/client_connectors/client_connector_tokio.rs @@ -0,0 +1,46 @@ +use crate::client::ClientConnector; +use crate::Request; +use enfync::TryAdopt; +use tokio_tungstenite::tungstenite; + +/// Implementation of [`ClientConnector`] for tokio runtimes. +#[derive(Clone)] +pub struct ClientConnectorTokio { + handle: enfync::builtin::native::TokioHandle, +} + +impl ClientConnectorTokio { + pub fn new(handle: tokio::runtime::Handle) -> Self { + Self { handle: handle.into() } + } +} + +impl Default for ClientConnectorTokio { + fn default() -> Self { + let handle = enfync::builtin::native::TokioHandle::try_adopt() + .expect("ClientConnectorTokio::default() only works inside a tokio runtime; use ClientConnectorTokionew() instead"); + Self { handle } + } +} + +#[async_trait::async_trait] +impl ClientConnector for ClientConnectorTokio { + type Handle = enfync::builtin::native::TokioHandle; + type Message = tungstenite::Message; + type WSError = tungstenite::error::Error; + type Socket = tokio_tungstenite::WebSocketStream>; + type ConnectError = tungstenite::error::Error; + + /// Get the connector's runtime handle. + fn handle(&self) -> Self::Handle { + self.handle.clone() + } + + /// Connect to a websocket server. + /// + /// Returns `Err` if the request is invalid. + async fn connect(&self, request: Request) -> Result { + let (socket, _) = tokio_tungstenite::connect_async(request).await?; + Ok(socket) + } +} diff --git a/src/client_connectors/mod.rs b/src/client_connectors/mod.rs new file mode 100644 index 0000000..b7914f2 --- /dev/null +++ b/src/client_connectors/mod.rs @@ -0,0 +1,6 @@ + +#[cfg(feature = "native_client")] +pub mod client_connector_tokio; + +//#[cfg(feature = "wasm_client")] +//pub mod client_connector_wasm; diff --git a/src/lib.rs b/src/lib.rs index 2604f00..30f352b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,11 @@ //! //! Refer to [`client`] or [`server`] module for detailed implementation guides. +mod client_connectors; mod server_runners; mod socket; +pub use client_connectors::*; pub use server_runners::*; pub use socket::CloseCode;