Skip to content

Commit

Permalink
use in-repo axum_tungstenite intead of axum::extract::ws for unified …
Browse files Browse the repository at this point in the history
…message/error types
  • Loading branch information
UkoeHB committed Oct 4, 2023
1 parent 0cc591e commit 055417a
Show file tree
Hide file tree
Showing 11 changed files with 871 additions and 307 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Add `ClientConfig::socket_config()` setter so clients can define their socket's config.
- Add `ezsockets::axum::Upgrade::on_upgrade_with_config()` that accepts a `SocketConfig`.
- Refactor `ezeockets::client::connect()` to use a retry loop for the initial connection. Add `max_initial_connect_attempts` and `max_reconnect_attempts` options to the `ClientConfig` (they default to 'infinite').
- Move `axum` and `tungstenite` server runners into new submodule `src/server_runners`.
- Update to `tokio-tungstenite` v0.20.0.
- Fork [axum-tungstenite](https://crates.io/crates/axum-tungstenite) crate into `src/server_runners` and refactor the `axum` runner to use that instead of `axum::extract::ws`.


Migration guide:
Expand Down
23 changes: 16 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,34 @@ async-trait = "0.1.52"
atomic_enum = "0.2.0"
base64 = "0.21.0"
futures = "0.3.21"
http = "0.2.6"
http = "0.2.8"
tokio = { version = "1.17.0", features = ["sync", "rt", "macros", "time"] }
tracing = "0.1.31"
url = "2.2.2"
cfg-if = "1.0.0"

axum = { version = "0.6.1", features = ["ws"], optional = true }
tokio-tungstenite = { version = "0.18.0", optional = true }
tokio-rustls = { version = "0.23.4", optional = true }
axum = { version = "0.6.1", optional = true }
axum-core = { version = "0.3.0", optional = true }
bytes = { version = "1.3.0", optional = true }
futures-util = { version = "0.3.25", default-features = false, features = ["alloc"], optional = true }
http-body = { version = "0.4.5", optional = true }
hyper = { version = "0.14.23", optional = true }
sha-1 = { version = "0.10.1", optional = true }

tokio-tungstenite = { version = "0.20.0", optional = true }
tokio-rustls = { version = "0.24.1", optional = true }
tokio-native-tls = { version = "0.3.1", optional = true }

[features]
default = ["client", "server"]

client = ["tokio-tungstenite"]

server = []
tungstenite = ["server", "tokio-tungstenite"]
axum = ["server", "dep:axum"]
tungstenite_common = ["tokio-tungstenite"]

server = ["tungstenite_common"]
tungstenite = ["server"]
axum = ["server", "dep:axum", "axum-core", "bytes", "futures-util", "http-body", "hyper", "sha-1"]

tls = []
native-tls = ["tls", "tokio-native-tls", "tokio-tungstenite/native-tls"]
Expand Down
4 changes: 2 additions & 2 deletions benches/my_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ fn bench(b: &mut Bencher, client: &mut WebSocket<MaybeTlsStream<TcpStream>>) {
let nonce = rng.gen::<u32>();
let text = format!("Hello {}", nonce);
let message = tungstenite::Message::Text(text.clone());
client.write_message(message).unwrap();
client.send(message).unwrap();

while let tungstenite::Message::Text(received_text) = client.read_message().unwrap() {
while let tungstenite::Message::Text(received_text) = client.read().unwrap() {
return Some(received_text);
}
return None;
Expand Down
4 changes: 2 additions & 2 deletions benches/tungstenite_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ pub fn run(listener: std::net::TcpListener) {
while let Ok((stream, _address)) = listener.accept() {
let mut websocket = accept(stream).unwrap();
loop {
let message = websocket.read_message().unwrap();
let message = websocket.read().unwrap();
// println!("server | msg: {:?}", &message);
match message {
Message::Text(text) => websocket.write_message(Message::Text(text)).unwrap(),
Message::Text(text) => websocket.send(Message::Text(text)).unwrap(),
Message::Binary(_) => todo!(),
Message::Ping(_) => todo!(),
Message::Pong(_) => todo!(),
Expand Down
9 changes: 3 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
//!
//! Refer to [`client`] or [`server`] module for detailed implementation guides.
mod server_runners;
mod socket;

pub use server_runners::*;

pub use socket::CloseCode;
pub use socket::CloseFrame;
pub use socket::Message;
Expand All @@ -18,12 +21,6 @@ pub use socket::Socket;
pub use socket::SocketConfig;
pub use socket::Stream;

#[cfg(feature = "axum")]
pub mod axum;

#[cfg(feature = "tokio-tungstenite")]
pub mod tungstenite;

cfg_if::cfg_if! {
if #[cfg(feature = "client")] {
pub mod client;
Expand Down
47 changes: 6 additions & 41 deletions src/axum.rs → src/server_runners/axum.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! `axum` feature must be enabled in order to use this module.
//! The `axum` feature must be enabled in order to use this module.
//!
//! ```no_run
//! use async_trait::async_trait;
Expand Down Expand Up @@ -58,16 +58,13 @@
//! }
//! ```
use crate::server_runners::axum_tungstenite::rejection::*;
use crate::server_runners::axum_tungstenite::WebSocketUpgrade;
use crate::socket::SocketConfig;
use crate::CloseCode;
use crate::CloseFrame;
use crate::RawMessage;
use crate::Server;
use crate::ServerExt;
use crate::Socket;
use async_trait::async_trait;
use axum::extract::ws;
use axum::extract::ws::rejection::*;
use axum::extract::ConnectInfo;
use axum::extract::FromRequest;
use axum::response::Response;
Expand All @@ -82,7 +79,7 @@ use std::net::SocketAddr;
/// See the [module docs](self) for an example.
#[derive(Debug)]
pub struct Upgrade {
ws: ws::WebSocketUpgrade,
ws: WebSocketUpgrade,
address: SocketAddr,
request: crate::Request,
}
Expand Down Expand Up @@ -120,49 +117,17 @@ where
pure_req = pure_req.header(k, v);
}
let Ok(pure_req) = pure_req.body(()) else {
return Err(ConnectionNotUpgradable::default().into());
return Err(InvalidConnectionHeader {}.into());
};

Ok(Self {
ws: ws::WebSocketUpgrade::from_request(req, state).await?,
ws: WebSocketUpgrade::from_request(req, state).await?,
address,
request: pure_req,
})
}
}

impl From<ws::Message> for RawMessage {
fn from(message: ws::Message) -> Self {
match message {
ws::Message::Text(text) => RawMessage::Text(text),
ws::Message::Binary(binary) => RawMessage::Binary(binary),
ws::Message::Ping(ping) => RawMessage::Ping(ping),
ws::Message::Pong(pong) => RawMessage::Pong(pong),
ws::Message::Close(Some(close)) => RawMessage::Close(Some(CloseFrame {
code: CloseCode::try_from(close.code).unwrap_or(CloseCode::Abnormal),
reason: close.reason.into(),
})),
ws::Message::Close(None) => RawMessage::Close(None),
}
}
}

impl From<RawMessage> for ws::Message {
fn from(message: RawMessage) -> Self {
match message {
RawMessage::Text(text) => ws::Message::Text(text),
RawMessage::Binary(binary) => ws::Message::Binary(binary),
RawMessage::Ping(ping) => ws::Message::Ping(ping),
RawMessage::Pong(pong) => ws::Message::Pong(pong),
RawMessage::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame {
code: close.code.into(),
reason: close.reason.into(),
})),
RawMessage::Close(None) => ws::Message::Close(None),
}
}
}

impl Upgrade {
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
Expand Down
Loading

0 comments on commit 055417a

Please sign in to comment.