diff --git a/Cargo.toml b/Cargo.toml index 6af1eee..dde6af0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,15 +17,22 @@ async-trait = "0.1.52" atomic_enum = "0.2.0" base64 = "0.21.0" futures = "0.3.21" -http = "0.2.6" +http = "0.2.8" tokio = { version = "1.17.0", features = ["sync", "rt", "macros", "time"] } tracing = "0.1.31" url = "2.2.2" cfg-if = "1.0.0" -axum = { version = "0.6.1", features = ["ws"], optional = true } -tokio-tungstenite = { version = "0.18.0", optional = true } -tokio-rustls = { version = "0.23.4", optional = true } +axum = { version = "0.6.1", optional = true } +axum-core = { version = "0.3.0", optional = true } +bytes = { version = "1.3.0", optional = true } +futures-util = { version = "0.3.25", default-features = false, features = ["alloc"], optional = true } +http-body = { version = "0.4.5", optional = true } +hyper = { version = "0.14.23", optional = true } +sha-1 = { version = "0.10.1", optional = true } + +tokio-tungstenite = { version = "0.20.0", optional = true } +tokio-rustls = { version = "0.24.1", optional = true } tokio-native-tls = { version = "0.3.1", optional = true } [features] @@ -33,9 +40,11 @@ default = ["client", "server"] client = ["tokio-tungstenite"] -server = [] -tungstenite = ["server", "tokio-tungstenite"] -axum = ["server", "dep:axum"] +tungstenite_common = ["tokio-tungstenite"] + +server = ["tungstenite_common"] +tungstenite = ["server"] +axum = ["server", "dep:axum", "axum-core", "bytes", "futures-util", "http-body", "hyper", "sha-1"] tls = [] native-tls = ["tls", "tokio-native-tls", "tokio-tungstenite/native-tls"] diff --git a/benches/my_benchmark.rs b/benches/my_benchmark.rs index eb18ae2..57b8b16 100644 --- a/benches/my_benchmark.rs +++ b/benches/my_benchmark.rs @@ -18,9 +18,9 @@ fn bench(b: &mut Bencher, client: &mut WebSocket>) { let nonce = rng.gen::(); let text = format!("Hello {}", nonce); let message = tungstenite::Message::Text(text.clone()); - client.write_message(message).unwrap(); + client.send(message).unwrap(); - while let tungstenite::Message::Text(received_text) = client.read_message().unwrap() { + while let tungstenite::Message::Text(received_text) = client.read().unwrap() { return Some(received_text); } return None; diff --git a/benches/tungstenite_server.rs b/benches/tungstenite_server.rs index bd8efe3..c4a19c4 100644 --- a/benches/tungstenite_server.rs +++ b/benches/tungstenite_server.rs @@ -4,10 +4,10 @@ pub fn run(listener: std::net::TcpListener) { while let Ok((stream, _address)) = listener.accept() { let mut websocket = accept(stream).unwrap(); loop { - let message = websocket.read_message().unwrap(); + let message = websocket.read().unwrap(); // println!("server | msg: {:?}", &message); match message { - Message::Text(text) => websocket.write_message(Message::Text(text)).unwrap(), + Message::Text(text) => websocket.send(Message::Text(text)).unwrap(), Message::Binary(_) => todo!(), Message::Ping(_) => todo!(), Message::Pong(_) => todo!(), diff --git a/src/lib.rs b/src/lib.rs index ede9dff..f82a648 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ //! Refer to [`client`] or [`server`] module for detailed implementation guides. mod socket; +mod server_runners; pub use socket::CloseCode; pub use socket::CloseFrame; @@ -18,11 +19,7 @@ pub use socket::Socket; pub use socket::SocketConfig; pub use socket::Stream; -#[cfg(feature = "axum")] -pub mod axum; - -#[cfg(feature = "tokio-tungstenite")] -pub mod tungstenite; +pub use server_runners::*; cfg_if::cfg_if! { if #[cfg(feature = "client")] { diff --git a/src/axum.rs b/src/server_runners/axum.rs similarity index 75% rename from src/axum.rs rename to src/server_runners/axum.rs index 4660303..97ef409 100644 --- a/src/axum.rs +++ b/src/server_runners/axum.rs @@ -1,4 +1,4 @@ -//! `axum` feature must be enabled in order to use this module. +//! The `axum` feature must be enabled in order to use this module. //! //! ```no_run //! use async_trait::async_trait; @@ -59,15 +59,12 @@ //! ``` use crate::socket::SocketConfig; -use crate::CloseCode; -use crate::CloseFrame; -use crate::RawMessage; +use crate::server_runners::axum_tungstenite::WebSocketUpgrade; +use crate::server_runners::axum_tungstenite::rejection::*; use crate::Server; use crate::ServerExt; use crate::Socket; use async_trait::async_trait; -use axum::extract::ws; -use axum::extract::ws::rejection::*; use axum::extract::ConnectInfo; use axum::extract::FromRequest; use axum::response::Response; @@ -82,7 +79,7 @@ use std::net::SocketAddr; /// See the [module docs](self) for an example. #[derive(Debug)] pub struct Upgrade { - ws: ws::WebSocketUpgrade, + ws: WebSocketUpgrade, address: SocketAddr, request: crate::Request, } @@ -120,49 +117,17 @@ where pure_req = pure_req.header(k, v); } let Ok(pure_req) = pure_req.body(()) else { - return Err(ConnectionNotUpgradable::default().into()); + return Err(InvalidConnectionHeader{}.into()); }; Ok(Self { - ws: ws::WebSocketUpgrade::from_request(req, state).await?, + ws: WebSocketUpgrade::from_request(req, state).await?, address, request: pure_req, }) } } -impl From for RawMessage { - fn from(message: ws::Message) -> Self { - match message { - ws::Message::Text(text) => RawMessage::Text(text), - ws::Message::Binary(binary) => RawMessage::Binary(binary), - ws::Message::Ping(ping) => RawMessage::Ping(ping), - ws::Message::Pong(pong) => RawMessage::Pong(pong), - ws::Message::Close(Some(close)) => RawMessage::Close(Some(CloseFrame { - code: CloseCode::try_from(close.code).unwrap_or(CloseCode::Abnormal), - reason: close.reason.into(), - })), - ws::Message::Close(None) => RawMessage::Close(None), - } - } -} - -impl From for ws::Message { - fn from(message: RawMessage) -> Self { - match message { - RawMessage::Text(text) => ws::Message::Text(text), - RawMessage::Binary(binary) => ws::Message::Binary(binary), - RawMessage::Ping(ping) => ws::Message::Ping(ping), - RawMessage::Pong(pong) => ws::Message::Pong(pong), - RawMessage::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame { - code: close.code.into(), - reason: close.reason.into(), - })), - RawMessage::Close(None) => ws::Message::Close(None), - } - } -} - impl Upgrade { /// Finalize upgrading the connection and call the provided callback with /// the stream. diff --git a/src/server_runners/axum_tungstenite.rs b/src/server_runners/axum_tungstenite.rs new file mode 100644 index 0000000..f3072dc --- /dev/null +++ b/src/server_runners/axum_tungstenite.rs @@ -0,0 +1,574 @@ +//! Forked from [axum-tungstenite](https://crates.io/crates/axum-tungstenite). +//! +//! This module implements an axum-based websockets layer on top of tungstenite, as an alternative to `axum::extract::ws` +//! which does not expose tungstenite types. + +#![deny(unreachable_pub, private_in_public)] +#![allow(clippy::type_complexity)] +#![forbid(unsafe_code)] +#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] +#![cfg_attr(test, allow(clippy::float_cmp))] + +use self::rejection::*; +use async_trait::async_trait; +use axum_core::{ + extract::FromRequestParts, + response::{IntoResponse, Response}, +}; +use bytes::Bytes; +use futures_util::{ + sink::{Sink, SinkExt}, + stream::{Stream, StreamExt}, +}; +use http::{ + header::{self, HeaderMap, HeaderName, HeaderValue}, + request::Parts, + Method, StatusCode, +}; +use hyper::upgrade::{OnUpgrade, Upgraded}; +use sha1::{Digest, Sha1}; +use std::{ + borrow::Cow, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio_tungstenite::{ + tungstenite::protocol::{self, WebSocketConfig}, + WebSocketStream, +}; + +#[doc(no_inline)] +pub use tokio_tungstenite::tungstenite::error::{ + CapacityError, Error, ProtocolError, TlsError, UrlError, +}; +#[doc(no_inline)] +pub use tokio_tungstenite::tungstenite::Message; + +/// Extractor for establishing WebSocket connections. +/// +/// See the [module docs](self) for an example. +#[derive(Debug)] +pub struct WebSocketUpgrade { + config: WebSocketConfig, + /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. + protocol: Option, + sec_websocket_key: HeaderValue, + on_upgrade: OnUpgrade, + on_failed_upgrade: F, + sec_websocket_protocol: Option, +} + +impl WebSocketUpgrade { + /// The target minimum size of the write buffer to reach before writing the data + /// to the underlying stream. + /// + /// The default value is 128 KiB. + /// + /// If set to `0` each message will be eagerly written to the underlying stream. + /// It is often more optimal to allow them to buffer a little, hence the default value. + /// + /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless. + pub fn write_buffer_size(mut self, size: usize) -> Self { + self.config.write_buffer_size = size; + self + } + + /// The max size of the write buffer in bytes. Setting this can provide backpressure + /// in the case the write buffer is filling up due to write errors. + /// + /// The default value is unlimited. + /// + /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size) + /// when writes to the underlying stream are failing. So the **write buffer can not + /// fill up if you are not observing write errors even if not flushing**. + /// + /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size) + /// and probably a little more depending on error handling strategy. + pub fn max_write_buffer_size(mut self, max: usize) -> Self { + self.config.max_write_buffer_size = max; + self + } + + /// Set the maximum message size (defaults to 64 megabytes) + pub fn max_message_size(mut self, max: usize) -> Self { + self.config.max_message_size = Some(max); + self + } + + /// Set the maximum frame size (defaults to 16 megabytes) + pub fn max_frame_size(mut self, max: usize) -> Self { + self.config.max_frame_size = Some(max); + self + } + + /// Allow server to accept unmasked frames (defaults to false) + pub fn accept_unmasked_frames(mut self, accept: bool) -> Self { + self.config.accept_unmasked_frames = accept; + self + } + + /// Set the known protocols. + /// + /// If the protocol name specified by `Sec-WebSocket-Protocol` header + /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and + /// return the protocol name. + /// + /// The protocols should be listed in decreasing order of preference: if the client offers + /// multiple protocols that the server could support, the server will pick the first one in + /// this list. + pub fn protocols(mut self, protocols: I) -> Self + where + I: IntoIterator, + I::Item: Into>, + { + if let Some(req_protocols) = self + .sec_websocket_protocol + .as_ref() + .and_then(|p| p.to_str().ok()) + { + self.protocol = protocols + .into_iter() + .map(Into::into) + .find(|protocol| { + req_protocols + .split(',') + .any(|req_protocol| req_protocol.trim() == protocol) + }) + .map(|protocol| match protocol { + Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(), + Cow::Borrowed(s) => HeaderValue::from_static(s), + }); + } + + self + } + + /// Finalize upgrading the connection and call the provided callback with + /// the stream. + /// + /// When using `WebSocketUpgrade`, the response produced by this method + /// should be returned from the handler. See the [module docs](self) for an + /// example. + pub fn on_upgrade(self, callback: F) -> Response + where + F: FnOnce(WebSocket) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + C: OnFailedUpdgrade, + { + let on_upgrade = self.on_upgrade; + let config = self.config; + let on_failed_upgrade = self.on_failed_upgrade; + + let protocol = self.protocol.clone(); + + tokio::spawn(async move { + let upgraded = match on_upgrade.await { + Ok(upgraded) => upgraded, + Err(err) => { + on_failed_upgrade.call(err); + return; + } + }; + + let socket = + WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config)) + .await; + let socket = WebSocket { + inner: socket, + protocol, + }; + callback(socket).await; + }); + + #[allow(clippy::declare_interior_mutable_const)] + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + + let mut headers = HeaderMap::new(); + headers.insert(header::CONNECTION, UPGRADE); + headers.insert(header::UPGRADE, WEBSOCKET); + headers.insert( + header::SEC_WEBSOCKET_ACCEPT, + sign(self.sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = self.protocol { + headers.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } + + (StatusCode::SWITCHING_PROTOCOLS, headers).into_response() + } + + /// Provide a callback to call if upgrading the connection fails. + /// + /// The connection upgrade is performed in a background task. If that fails this callback + /// will be called. + /// + /// By default any errors will be silently ignored. + /// + /// # Example + /// + /// ``` + /// use axum::response::Response; + /// use axum_tungstenite::WebSocketUpgrade; + /// + /// async fn handler(ws: WebSocketUpgrade) -> Response { + /// ws.on_failed_upgrade(|error| { + /// report_error(error); + /// }) + /// .on_upgrade(|socket| async { /* ... */ }) + /// } + /// # + /// # fn report_error(_: hyper::Error) {} + /// ``` + pub fn on_failed_upgrade(self, callback: C2) -> WebSocketUpgrade + where + C2: OnFailedUpdgrade, + { + WebSocketUpgrade { + config: self.config, + protocol: self.protocol, + sec_websocket_key: self.sec_websocket_key, + on_upgrade: self.on_upgrade, + on_failed_upgrade: callback, + sec_websocket_protocol: self.sec_websocket_protocol, + } + } +} + +#[async_trait] +impl FromRequestParts for WebSocketUpgrade +where + S: Sync, +{ + type Rejection = WebSocketUpgradeRejection; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + if parts.method != Method::GET { + return Err(MethodNotGet.into()); + } + + if !header_contains(parts, header::CONNECTION, "upgrade") { + return Err(InvalidConnectionHeader.into()); + } + + if !header_eq(parts, header::UPGRADE, "websocket") { + return Err(InvalidUpgradeHeader.into()); + } + + if !header_eq(parts, header::SEC_WEBSOCKET_VERSION, "13") { + return Err(InvalidWebSocketVersionHeader.into()); + } + + let sec_websocket_key = parts + .headers + .get(header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketKeyHeaderMissing)? + .clone(); + + let on_upgrade = parts + .extensions + .remove::() + .ok_or(ConnectionNotUpgradable)?; + + let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned(); + + Ok(Self { + config: Default::default(), + protocol: None, + sec_websocket_key, + on_upgrade, + on_failed_upgrade: DefaultOnFailedUpdgrade, + sec_websocket_protocol, + }) + } +} + +fn header_eq(req: &Parts, key: HeaderName, value: &'static str) -> bool { + if let Some(header) = req.headers.get(&key) { + header.as_bytes().eq_ignore_ascii_case(value.as_bytes()) + } else { + false + } +} + +fn header_contains(req: &Parts, key: HeaderName, value: &'static str) -> bool { + let header = if let Some(header) = req.headers.get(&key) { + header + } else { + return false; + }; + + if let Ok(header) = std::str::from_utf8(header.as_bytes()) { + header.to_ascii_lowercase().contains(value) + } else { + false + } +} + +/// A stream of WebSocket messages. +#[derive(Debug)] +pub struct WebSocket { + inner: WebSocketStream, + protocol: Option, +} + +impl WebSocket { + /// Consume `self` and get the inner [`tokio_tungstenite::WebSocketStream`]. + pub fn into_inner(self) -> WebSocketStream { + self.inner + } + + /// Receive another message. + /// + /// Returns `None` if the stream has closed. + pub async fn recv(&mut self) -> Option> { + self.next().await + } + + /// Send a message. + pub async fn send(&mut self, msg: Message) -> Result<(), Error> { + self.inner.send(msg).await + } + + /// Gracefully close this WebSocket. + pub async fn close(mut self) -> Result<(), Error> { + self.inner.close(None).await + } + + /// Return the selected WebSocket subprotocol, if one has been chosen. + pub fn protocol(&self) -> Option<&HeaderValue> { + self.protocol.as_ref() + } +} + +impl Stream for WebSocket { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_next_unpin(cx) + } +} + +impl Sink for WebSocket { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } +} + +fn sign(key: &[u8]) -> HeaderValue { + use base64::engine::Engine as _; + + let mut sha1 = Sha1::default(); + sha1.update(key); + sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]); + let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize())); + HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value") +} + +/// What to do when a connection upgrade fails. +/// +/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details. +pub trait OnFailedUpdgrade: Send + 'static { + /// Call the callback. + fn call(self, error: hyper::Error); +} + +impl OnFailedUpdgrade for F +where + F: FnOnce(hyper::Error) + Send + 'static, +{ + fn call(self, error: hyper::Error) { + self(error) + } +} + +/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`. +/// +/// It simply ignores the error. +#[non_exhaustive] +#[derive(Debug)] +pub struct DefaultOnFailedUpdgrade; + +impl OnFailedUpdgrade for DefaultOnFailedUpdgrade { + #[inline] + fn call(self, _error: hyper::Error) {} +} + +pub mod rejection { + //! WebSocket specific rejections. + + use super::*; + + macro_rules! define_rejection { + ( + #[status = $status:ident] + #[body = $body:expr] + $(#[$m:meta])* + pub struct $name:ident; + ) => { + $(#[$m])* + #[derive(Debug)] + #[non_exhaustive] + pub struct $name; + + impl IntoResponse for $name { + fn into_response(self) -> Response { + (http::StatusCode::$status, $body).into_response() + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $body) + } + } + + impl std::error::Error for $name {} + + impl Default for $name { + fn default() -> Self { + Self + } + } + }; + } + + define_rejection! { + #[status = METHOD_NOT_ALLOWED] + #[body = "Request method must be `GET`"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct MethodNotGet; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "Connection header did not include 'upgrade'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidConnectionHeader; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`Upgrade` header did not include 'websocket'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidUpgradeHeader; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`Sec-WebSocket-Version` header did not include '13'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidWebSocketVersionHeader; + } + + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`Sec-WebSocket-Key` header missing"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct WebSocketKeyHeaderMissing; + } + + define_rejection! { + #[status = UPGRADE_REQUIRED] + #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + /// + /// This rejection is returned if the connection cannot be upgraded for example if the + /// request is HTTP/1.0. + /// + /// See [MDN] for more details about connection upgrades. + /// + /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade + pub struct ConnectionNotUpgradable; + } + + macro_rules! composite_rejection { + ( + $(#[$m:meta])* + pub enum $name:ident { + $($variant:ident),+ + $(,)? + } + ) => { + $(#[$m])* + #[derive(Debug)] + #[non_exhaustive] + pub enum $name { + $( + #[allow(missing_docs)] + $variant($variant) + ),+ + } + + impl IntoResponse for $name { + fn into_response(self) -> Response { + match self { + $( + Self::$variant(inner) => inner.into_response(), + )+ + } + } + } + + $( + impl From<$variant> for $name { + fn from(inner: $variant) -> Self { + Self::$variant(inner) + } + } + )+ + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + $( + Self::$variant(inner) => write!(f, "{}", inner), + )+ + } + } + } + + impl std::error::Error for $name { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + $( + Self::$variant(inner) => Some(inner), + )+ + } + } + } + }; + } + + composite_rejection! { + /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade). + /// + /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade) + /// extractor can fail. + pub enum WebSocketUpgradeRejection { + MethodNotGet, + InvalidConnectionHeader, + InvalidUpgradeHeader, + InvalidWebSocketVersionHeader, + WebSocketKeyHeaderMissing, + ConnectionNotUpgradable, + } + } +} diff --git a/src/server_runners/mod.rs b/src/server_runners/mod.rs new file mode 100644 index 0000000..5fdecc3 --- /dev/null +++ b/src/server_runners/mod.rs @@ -0,0 +1,13 @@ + +cfg_if::cfg_if! { + if #[cfg(feature = "axum")] { + pub mod axum; + pub mod axum_tungstenite; + } +} + +#[cfg(feature = "tokio-tungstenite")] +pub mod tungstenite; + +#[cfg(feature = "tungstenite_common")] +pub mod tungstenite_common; diff --git a/src/server_runners/tungstenite.rs b/src/server_runners/tungstenite.rs new file mode 100644 index 0000000..937ae31 --- /dev/null +++ b/src/server_runners/tungstenite.rs @@ -0,0 +1,147 @@ +//! The `tungstenite` feature must be enabled in order to use this module. +//! +//! ```no_run +//! # use async_trait::async_trait; +//! # struct MySession {} +//! # #[async_trait::async_trait] +//! # impl ezsockets::SessionExt for MySession { +//! # type ID = u16; +//! # type Call = (); +//! # fn id(&self) -> &Self::ID { unimplemented!() } +//! # async fn on_text(&mut self, text: String) -> Result<(), ezsockets::Error> { unimplemented!() } +//! # async fn on_binary(&mut self, bytes: Vec) -> Result<(), ezsockets::Error> { unimplemented!() } +//! # async fn on_call(&mut self, call: Self::Call) -> Result<(), ezsockets::Error> { unimplemented!() } +//! # } +//! struct MyServer {} +//! +//! #[async_trait] +//! impl ezsockets::ServerExt for MyServer { +//! // ... +//! # type Session = MySession; +//! # type Call = (); +//! # async fn on_connect(&mut self, socket: ezsockets::Socket, request: ezsockets::Request, address: std::net::SocketAddr) -> Result, Option> { unimplemented!() } +//! # async fn on_disconnect(&mut self, id: ::ID, reason: Result, ezsockets::Error>) -> Result<(), ezsockets::Error> { unimplemented!() } +//! # async fn on_call(&mut self, call: Self::Call) -> Result<(), ezsockets::Error> { unimplemented!() } +//! } +//! +//! #[tokio::main] +//! async fn main() { +//! let (server, _) = ezsockets::Server::create(|_| MyServer {}); +//! ezsockets::tungstenite::run(server, "127.0.0.1:8080").await.unwrap(); +//! } +//! ``` + +use crate::Server; +use crate::Error; +use crate::Socket; +use crate::SocketConfig; +use crate::ServerExt; +use crate::Request; +use crate::tungstenite::tungstenite::handshake::server::ErrorResponse; + +use tokio_tungstenite::tungstenite; + +use tokio::net::TcpListener; +use tokio::net::ToSocketAddrs; +use tokio::net::TcpStream; + +pub enum Acceptor { + Plain, + #[cfg(feature = "native-tls")] + NativeTls(tokio_native_tls::TlsAcceptor), + #[cfg(feature = "rustls")] + Rustls(tokio_rustls::TlsAcceptor), +} + +impl Acceptor { + async fn accept(&self, stream: TcpStream) -> Result<(Socket, Request), Error> { + let mut req0 = None; + let callback = |req: &http::Request<()>, resp: http::Response<()>| -> Result, ErrorResponse> { + let mut req1 = Request::builder() + .method(req.method().clone()) + .uri(req.uri().clone()) + .version(req.version()); + for (k, v) in req.headers() { + req1 = req1.header(k, v); + } + let Ok(body) = req1.body(()) else { return Err(ErrorResponse::default()); }; + req0 = Some(body); + + Ok(resp) + }; + let socket = match self { + Acceptor::Plain => { + let socket = tokio_tungstenite::accept_hdr_async(stream, callback).await?; + Socket::new(socket, SocketConfig::default()) + } + #[cfg(feature = "native-tls")] + Acceptor::NativeTls(acceptor) => { + let tls_stream = acceptor.accept(stream).await?; + let socket = tokio_tungstenite::accept_hdr_async(tls_stream, callback).await?; + Socket::new(socket, SocketConfig::default()) + } + #[cfg(feature = "rustls")] + Acceptor::Rustls(acceptor) => { + let tls_stream = acceptor.accept(stream).await?; + let socket = tokio_tungstenite::accept_hdr_async(tls_stream, callback).await?; + Socket::new(socket, SocketConfig::default()) + } + }; + let Some(req_body) = req0 else { return Err("invalid request body".into()); }; + Ok((socket, req_body)) + } +} + +async fn run_acceptor( + server: Server, + listener: TcpListener, + acceptor: Acceptor, +) -> Result<(), Error> +where + E: ServerExt + 'static +{ + loop { + // TODO: Find a better way without those stupid matches + let (stream, address) = match listener.accept().await { + Ok(stream) => stream, + Err(err) => { + tracing::warn!("failed to accept tcp connection: {:?}", err); + continue; + }, + }; + let (socket, request) = match acceptor.accept(stream).await { + Ok(socket) => socket, + Err(err) => { + tracing::warn!(%address, "failed to accept websocket connection: {:?}", err); + continue; + } + }; + server.accept(socket, request, address); + } +} + +// Run the server +pub async fn run( + server: Server, + address: A, +) -> Result<(), Error> +where + E: ServerExt + 'static, + A: ToSocketAddrs, +{ + let listener = TcpListener::bind(address).await?; + run_acceptor(server, listener, Acceptor::Plain).await +} + +/// Run the server on custom `Listener` and `Acceptor` +/// For default acceptor use `Acceptor::plain` +pub async fn run_on( + server: Server, + listener: TcpListener, + acceptor: Acceptor, +) -> Result<(), Error> +where + E: ServerExt + 'static +{ + run_acceptor(server, listener, acceptor).await +} diff --git a/src/server_runners/tungstenite_common.rs b/src/server_runners/tungstenite_common.rs new file mode 100644 index 0000000..e0f1e4a --- /dev/null +++ b/src/server_runners/tungstenite_common.rs @@ -0,0 +1,103 @@ +//! The `tungstenite_common` feature must be enabled in order to use this module. +//! + +use crate::socket::RawMessage; +use crate::CloseCode; +use crate::CloseFrame; +use crate::Message; +use tokio_tungstenite::tungstenite; +use tungstenite::protocol::frame::coding::CloseCode as TungsteniteCloseCode; + +impl<'t> From> for CloseFrame { + fn from(frame: tungstenite::protocol::CloseFrame) -> Self { + Self { + code: frame.code.into(), + reason: frame.reason.into(), + } + } +} + +impl<'t> From for tungstenite::protocol::CloseFrame<'t> { + fn from(frame: CloseFrame) -> Self { + Self { + code: frame.code.into(), + reason: frame.reason.into(), + } + } +} + +impl From for TungsteniteCloseCode { + fn from(code: CloseCode) -> Self { + match code { + CloseCode::Normal => Self::Normal, + CloseCode::Away => Self::Away, + CloseCode::Protocol => Self::Protocol, + CloseCode::Unsupported => Self::Unsupported, + CloseCode::Status => Self::Status, + CloseCode::Abnormal => Self::Abnormal, + CloseCode::Invalid => Self::Invalid, + CloseCode::Policy => Self::Policy, + CloseCode::Size => Self::Size, + CloseCode::Extension => Self::Extension, + CloseCode::Error => Self::Error, + CloseCode::Restart => Self::Restart, + CloseCode::Again => Self::Again, + } + } +} + +impl From for CloseCode { + fn from(code: TungsteniteCloseCode) -> Self { + match code { + TungsteniteCloseCode::Normal => Self::Normal, + TungsteniteCloseCode::Away => Self::Away, + TungsteniteCloseCode::Protocol => Self::Protocol, + TungsteniteCloseCode::Unsupported => Self::Unsupported, + TungsteniteCloseCode::Status => Self::Status, + TungsteniteCloseCode::Abnormal => Self::Abnormal, + TungsteniteCloseCode::Invalid => Self::Invalid, + TungsteniteCloseCode::Policy => Self::Policy, + TungsteniteCloseCode::Size => Self::Size, + TungsteniteCloseCode::Extension => Self::Extension, + TungsteniteCloseCode::Error => Self::Error, + TungsteniteCloseCode::Restart => Self::Restart, + TungsteniteCloseCode::Again => Self::Again, + code => unimplemented!("could not handle close code: {code:?}"), + } + } +} + +impl From for tungstenite::Message { + fn from(message: RawMessage) -> Self { + match message { + RawMessage::Text(text) => Self::Text(text), + RawMessage::Binary(bytes) => Self::Binary(bytes), + RawMessage::Ping(bytes) => Self::Ping(bytes), + RawMessage::Pong(bytes) => Self::Pong(bytes), + RawMessage::Close(frame) => Self::Close(frame.map(CloseFrame::into)), + } + } +} + +impl From for RawMessage { + fn from(message: tungstenite::Message) -> Self { + match message { + tungstenite::Message::Text(text) => Self::Text(text), + tungstenite::Message::Binary(bytes) => Self::Binary(bytes), + tungstenite::Message::Ping(bytes) => Self::Ping(bytes), + tungstenite::Message::Pong(bytes) => Self::Pong(bytes), + tungstenite::Message::Close(frame) => Self::Close(frame.map(CloseFrame::from)), + tungstenite::Message::Frame(_) => unreachable!(), + } + } +} + +impl From for tungstenite::Message { + fn from(message: Message) -> Self { + match message { + Message::Text(text) => tungstenite::Message::Text(text), + Message::Binary(bytes) => tungstenite::Message::Binary(bytes), + Message::Close(frame) => tungstenite::Message::Close(frame.map(CloseFrame::into)), + } + } +} diff --git a/src/tungstenite.rs b/src/tungstenite.rs deleted file mode 100644 index d3c6c5a..0000000 --- a/src/tungstenite.rs +++ /dev/null @@ -1,249 +0,0 @@ -//! `tungstenite` feature must be enabled in order to use this module. -//! -//! ```no_run -//! # use async_trait::async_trait; -//! # struct MySession {} -//! # #[async_trait::async_trait] -//! # impl ezsockets::SessionExt for MySession { -//! # type ID = u16; -//! # type Call = (); -//! # fn id(&self) -> &Self::ID { unimplemented!() } -//! # async fn on_text(&mut self, text: String) -> Result<(), ezsockets::Error> { unimplemented!() } -//! # async fn on_binary(&mut self, bytes: Vec) -> Result<(), ezsockets::Error> { unimplemented!() } -//! # async fn on_call(&mut self, call: Self::Call) -> Result<(), ezsockets::Error> { unimplemented!() } -//! # } -//! struct MyServer {} -//! -//! #[async_trait] -//! impl ezsockets::ServerExt for MyServer { -//! // ... -//! # type Session = MySession; -//! # type Call = (); -//! # async fn on_connect(&mut self, socket: ezsockets::Socket, request: ezsockets::Request, address: std::net::SocketAddr) -> Result, Option> { unimplemented!() } -//! # async fn on_disconnect(&mut self, id: ::ID, reason: Result, ezsockets::Error>) -> Result<(), ezsockets::Error> { unimplemented!() } -//! # async fn on_call(&mut self, call: Self::Call) -> Result<(), ezsockets::Error> { unimplemented!() } -//! } -//! -//! #[tokio::main] -//! async fn main() { -//! let (server, _) = ezsockets::Server::create(|_| MyServer {}); -//! ezsockets::tungstenite::run(server, "127.0.0.1:8080").await.unwrap(); -//! } -//! ``` - -use crate::socket::{RawMessage, SocketConfig}; -use crate::tungstenite::tungstenite::handshake::server::ErrorResponse; -use crate::CloseCode; -use crate::CloseFrame; -use crate::Message; -use crate::Request; -use tokio_tungstenite::tungstenite; -use tungstenite::protocol::frame::coding::CloseCode as TungsteniteCloseCode; - -impl<'t> From> for CloseFrame { - fn from(frame: tungstenite::protocol::CloseFrame) -> Self { - Self { - code: frame.code.into(), - reason: frame.reason.into(), - } - } -} - -impl<'t> From for tungstenite::protocol::CloseFrame<'t> { - fn from(frame: CloseFrame) -> Self { - Self { - code: frame.code.into(), - reason: frame.reason.into(), - } - } -} - -impl From for TungsteniteCloseCode { - fn from(code: CloseCode) -> Self { - match code { - CloseCode::Normal => Self::Normal, - CloseCode::Away => Self::Away, - CloseCode::Protocol => Self::Protocol, - CloseCode::Unsupported => Self::Unsupported, - CloseCode::Status => Self::Status, - CloseCode::Abnormal => Self::Abnormal, - CloseCode::Invalid => Self::Invalid, - CloseCode::Policy => Self::Policy, - CloseCode::Size => Self::Size, - CloseCode::Extension => Self::Extension, - CloseCode::Error => Self::Error, - CloseCode::Restart => Self::Restart, - CloseCode::Again => Self::Again, - } - } -} - -impl From for CloseCode { - fn from(code: TungsteniteCloseCode) -> Self { - match code { - TungsteniteCloseCode::Normal => Self::Normal, - TungsteniteCloseCode::Away => Self::Away, - TungsteniteCloseCode::Protocol => Self::Protocol, - TungsteniteCloseCode::Unsupported => Self::Unsupported, - TungsteniteCloseCode::Status => Self::Status, - TungsteniteCloseCode::Abnormal => Self::Abnormal, - TungsteniteCloseCode::Invalid => Self::Invalid, - TungsteniteCloseCode::Policy => Self::Policy, - TungsteniteCloseCode::Size => Self::Size, - TungsteniteCloseCode::Extension => Self::Extension, - TungsteniteCloseCode::Error => Self::Error, - TungsteniteCloseCode::Restart => Self::Restart, - TungsteniteCloseCode::Again => Self::Again, - code => unimplemented!("could not handle close code: {code:?}"), - } - } -} - -impl From for tungstenite::Message { - fn from(message: RawMessage) -> Self { - match message { - RawMessage::Text(text) => Self::Text(text), - RawMessage::Binary(bytes) => Self::Binary(bytes), - RawMessage::Ping(bytes) => Self::Ping(bytes), - RawMessage::Pong(bytes) => Self::Pong(bytes), - RawMessage::Close(frame) => Self::Close(frame.map(CloseFrame::into)), - } - } -} - -impl From for RawMessage { - fn from(message: tungstenite::Message) -> Self { - match message { - tungstenite::Message::Text(text) => Self::Text(text), - tungstenite::Message::Binary(bytes) => Self::Binary(bytes), - tungstenite::Message::Ping(bytes) => Self::Ping(bytes), - tungstenite::Message::Pong(bytes) => Self::Pong(bytes), - tungstenite::Message::Close(frame) => Self::Close(frame.map(CloseFrame::from)), - tungstenite::Message::Frame(_) => unreachable!(), - } - } -} - -impl From for tungstenite::Message { - fn from(message: Message) -> Self { - match message { - Message::Text(text) => tungstenite::Message::Text(text), - Message::Binary(bytes) => tungstenite::Message::Binary(bytes), - Message::Close(frame) => tungstenite::Message::Close(frame.map(CloseFrame::into)), - } - } -} - -cfg_if::cfg_if! { - if #[cfg(feature = "server")] { - use crate::Server; - use crate::Error; - use crate::Socket; - use crate::ServerExt; - - use tokio::net::TcpListener; - use tokio::net::ToSocketAddrs; - use tokio::net::TcpStream; - - pub enum Acceptor { - Plain, - #[cfg(feature = "native-tls")] - NativeTls(tokio_native_tls::TlsAcceptor), - #[cfg(feature = "rustls")] - Rustls(tokio_rustls::TlsAcceptor), - } - - impl Acceptor { - async fn accept(&self, stream: TcpStream) -> Result<(Socket, Request), Error> { - let mut req0 = None; - let callback = |req: &http::Request<()>, resp: http::Response<()>| -> Result, ErrorResponse> { - let mut req1 = Request::builder() - .method(req.method().clone()) - .uri(req.uri().clone()) - .version(req.version()); - for (k, v) in req.headers() { - req1 = req1.header(k, v); - } - let Ok(body) = req1.body(()) else { return Err(ErrorResponse::default()); }; - req0 = Some(body); - - Ok(resp) - }; - let socket = match self { - Acceptor::Plain => { - let socket = tokio_tungstenite::accept_hdr_async(stream, callback).await?; - Socket::new(socket, SocketConfig::default()) - } - #[cfg(feature = "native-tls")] - Acceptor::NativeTls(acceptor) => { - let tls_stream = acceptor.accept(stream).await?; - let socket = tokio_tungstenite::accept_hdr_async(tls_stream, callback).await?; - Socket::new(socket, SocketConfig::default()) - } - #[cfg(feature = "rustls")] - Acceptor::Rustls(acceptor) => { - let tls_stream = acceptor.accept(stream).await?; - let socket = tokio_tungstenite::accept_hdr_async(tls_stream, callback).await?; - Socket::new(socket, SocketConfig::default()) - } - }; - let Some(req_body) = req0 else { return Err("invalid request body".into()); }; - Ok((socket, req_body)) - } - } - - async fn run_acceptor( - server: Server, - listener: TcpListener, - acceptor: Acceptor, - ) -> Result<(), Error> - where - E: ServerExt + 'static - { - loop { - // TODO: Find a better way without those stupid matches - let (stream, address) = match listener.accept().await { - Ok(stream) => stream, - Err(err) => { - tracing::warn!("failed to accept tcp connection: {:?}", err); - continue; - }, - }; - let (socket, request) = match acceptor.accept(stream).await { - Ok(socket) => socket, - Err(err) => { - tracing::warn!(%address, "failed to accept websocket connection: {:?}", err); - continue; - } - }; - server.accept(socket, request, address); - } - } - - // Run the server - pub async fn run( - server: Server, - address: A, - ) -> Result<(), Error> - where - E: ServerExt + 'static, - A: ToSocketAddrs, - { - let listener = TcpListener::bind(address).await?; - run_acceptor(server, listener, Acceptor::Plain).await - } - - /// Run the server on custom `Listener` and `Acceptor` - /// For default acceptor use `Acceptor::plain` - pub async fn run_on( - server: Server, - listener: TcpListener, - acceptor: Acceptor, - ) -> Result<(), Error> - where - E: ServerExt + 'static - { - run_acceptor(server, listener, acceptor).await - } - } -}