Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace axum with axum_tungstenite #84

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Add `ClientConfig::socket_config()` setter so clients can define their socket's config.
- Add `ezsockets::axum::Upgrade::on_upgrade_with_config()` that accepts a `SocketConfig`.
- Refactor `ezeockets::client::connect()` to use a retry loop for the initial connection. Add `max_initial_connect_attempts` and `max_reconnect_attempts` options to the `ClientConfig` (they default to 'infinite').
- Move `axum` and `tungstenite` server runners into new submodule `src/server_runners`.
- Update to `tokio-tungstenite` v0.20.0.
- Fork [axum-tungstenite](https://crates.io/crates/axum-tungstenite) crate into `src/server_runners` and refactor the `axum` runner to use that instead of `axum::extract::ws`.


Migration guide:
Expand Down
23 changes: 16 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,34 @@ async-trait = "0.1.52"
atomic_enum = "0.2.0"
base64 = "0.21.0"
futures = "0.3.21"
http = "0.2.6"
http = "0.2.8"
tokio = { version = "1.17.0", features = ["sync", "rt", "macros", "time"] }
tracing = "0.1.31"
url = "2.2.2"
cfg-if = "1.0.0"

axum = { version = "0.6.1", features = ["ws"], optional = true }
tokio-tungstenite = { version = "0.18.0", optional = true }
tokio-rustls = { version = "0.23.4", optional = true }
axum = { version = "0.6.1", optional = true }
axum-core = { version = "0.3.0", optional = true }
bytes = { version = "1.3.0", optional = true }
futures-util = { version = "0.3.25", default-features = false, features = ["alloc"], optional = true }
http-body = { version = "0.4.5", optional = true }
hyper = { version = "0.14.23", optional = true }
sha-1 = { version = "0.10.1", optional = true }

tokio-tungstenite = { version = "0.20.0", optional = true }
tokio-rustls = { version = "0.24.1", optional = true }
tokio-native-tls = { version = "0.3.1", optional = true }

[features]
default = ["client", "server"]

client = ["tokio-tungstenite"]

server = []
tungstenite = ["server", "tokio-tungstenite"]
axum = ["server", "dep:axum"]
tungstenite_common = ["tokio-tungstenite"]

server = ["tungstenite_common"]
tungstenite = ["server"]
axum = ["server", "dep:axum", "axum-core", "bytes", "futures-util", "http-body", "hyper", "sha-1"]

tls = []
native-tls = ["tls", "tokio-native-tls", "tokio-tungstenite/native-tls"]
Expand Down
4 changes: 2 additions & 2 deletions benches/my_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ fn bench(b: &mut Bencher, client: &mut WebSocket<MaybeTlsStream<TcpStream>>) {
let nonce = rng.gen::<u32>();
let text = format!("Hello {}", nonce);
let message = tungstenite::Message::Text(text.clone());
client.write_message(message).unwrap();
client.send(message).unwrap();

while let tungstenite::Message::Text(received_text) = client.read_message().unwrap() {
while let tungstenite::Message::Text(received_text) = client.read().unwrap() {
return Some(received_text);
}
return None;
Expand Down
4 changes: 2 additions & 2 deletions benches/tungstenite_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ pub fn run(listener: std::net::TcpListener) {
while let Ok((stream, _address)) = listener.accept() {
let mut websocket = accept(stream).unwrap();
loop {
let message = websocket.read_message().unwrap();
let message = websocket.read().unwrap();
// println!("server | msg: {:?}", &message);
match message {
Message::Text(text) => websocket.write_message(Message::Text(text)).unwrap(),
Message::Text(text) => websocket.send(Message::Text(text)).unwrap(),
Message::Binary(_) => todo!(),
Message::Ping(_) => todo!(),
Message::Pong(_) => todo!(),
Expand Down
9 changes: 3 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
//!
//! Refer to [`client`] or [`server`] module for detailed implementation guides.

mod server_runners;
mod socket;

pub use server_runners::*;

pub use socket::CloseCode;
pub use socket::CloseFrame;
pub use socket::Message;
Expand All @@ -18,12 +21,6 @@ pub use socket::Socket;
pub use socket::SocketConfig;
pub use socket::Stream;

#[cfg(feature = "axum")]
pub mod axum;

#[cfg(feature = "tokio-tungstenite")]
pub mod tungstenite;

cfg_if::cfg_if! {
if #[cfg(feature = "client")] {
pub mod client;
Expand Down
47 changes: 6 additions & 41 deletions src/axum.rs → src/server_runners/axum.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! `axum` feature must be enabled in order to use this module.
//! The `axum` feature must be enabled in order to use this module.
//!
//! ```no_run
//! use async_trait::async_trait;
Expand Down Expand Up @@ -58,16 +58,13 @@
//! }
//! ```

use crate::server_runners::axum_tungstenite::rejection::*;
use crate::server_runners::axum_tungstenite::WebSocketUpgrade;
use crate::socket::SocketConfig;
use crate::CloseCode;
use crate::CloseFrame;
use crate::RawMessage;
use crate::Server;
use crate::ServerExt;
use crate::Socket;
use async_trait::async_trait;
use axum::extract::ws;
use axum::extract::ws::rejection::*;
use axum::extract::ConnectInfo;
use axum::extract::FromRequest;
use axum::response::Response;
Expand All @@ -82,7 +79,7 @@ use std::net::SocketAddr;
/// See the [module docs](self) for an example.
#[derive(Debug)]
pub struct Upgrade {
ws: ws::WebSocketUpgrade,
ws: WebSocketUpgrade,
address: SocketAddr,
request: crate::Request,
}
Expand Down Expand Up @@ -120,49 +117,17 @@ where
pure_req = pure_req.header(k, v);
}
let Ok(pure_req) = pure_req.body(()) else {
return Err(ConnectionNotUpgradable::default().into());
return Err(InvalidConnectionHeader {}.into());
};

Ok(Self {
ws: ws::WebSocketUpgrade::from_request(req, state).await?,
ws: WebSocketUpgrade::from_request(req, state).await?,
address,
request: pure_req,
})
}
}

impl From<ws::Message> for RawMessage {
fn from(message: ws::Message) -> Self {
match message {
ws::Message::Text(text) => RawMessage::Text(text),
ws::Message::Binary(binary) => RawMessage::Binary(binary),
ws::Message::Ping(ping) => RawMessage::Ping(ping),
ws::Message::Pong(pong) => RawMessage::Pong(pong),
ws::Message::Close(Some(close)) => RawMessage::Close(Some(CloseFrame {
code: CloseCode::try_from(close.code).unwrap_or(CloseCode::Abnormal),
reason: close.reason.into(),
})),
ws::Message::Close(None) => RawMessage::Close(None),
}
}
}

impl From<RawMessage> for ws::Message {
fn from(message: RawMessage) -> Self {
match message {
RawMessage::Text(text) => ws::Message::Text(text),
RawMessage::Binary(binary) => ws::Message::Binary(binary),
RawMessage::Ping(ping) => ws::Message::Ping(ping),
RawMessage::Pong(pong) => ws::Message::Pong(pong),
RawMessage::Close(Some(close)) => ws::Message::Close(Some(ws::CloseFrame {
code: close.code.into(),
reason: close.reason.into(),
})),
RawMessage::Close(None) => ws::Message::Close(None),
}
}
}

impl Upgrade {
/// Finalize upgrading the connection and call the provided callback with
/// the stream.
Expand Down
Loading
Loading