Skip to content

Commit

Permalink
Hyper 1/Axum 0.7 support (#33)
Browse files Browse the repository at this point in the history
* Support for Hyper v1/Axum 0.7

* Working text streamer

* CSV stream fix

* JSON support

* Protobuf support

* Fixed tests

* Updated pipelines
  • Loading branch information
abdolence authored Dec 8, 2023
1 parent a81bf93 commit 7eb9f69
Show file tree
Hide file tree
Showing 14 changed files with 76 additions and 98 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ on:
workflow_dispatch:
push:
pull_request:
types: [opened]
types: [opened]
concurrency:
group: ${{ github.workflow }}-${{ github.ref_protected && github.run_id || github.event.pull_request.number || github.ref }}
cancel-in-progress: true
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand Down
9 changes: 5 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ name = "axum_streams"
path = "src/lib.rs"

[dependencies]
axum = { version = "0.6" }
axum = { version = "0.7" }
bytes = "1"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "0.2"
http = "1"
http-body = "1"
mime = "0.3"
tokio = "1"
serde = { version = "1", features = ["serde_derive"], optional = true }
Expand All @@ -44,10 +45,10 @@ text = []

[dev-dependencies]
futures = "0.3"
hyper = "0.14"
hyper = "1"
reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] }
tower = { version = "0.4", default-features = false, features = ["util", "make"] }
tower-http = { version = "0.4", features = ["util", "map-response-body"] }
tower-http = { version = "0.5", features = ["util", "map-response-body"] }
tower-layer = "0.3"
tower-service = "0.3"
tokio = { version = "1", features = ["full"] }
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ and want to avoid huge memory allocation.
Cargo.toml:
```toml
[dependencies]
axum-streams = { version = "0.10", features=["json", "csv", "protobuf", "text"] }
axum-streams = { version = "0.11", features=["json", "csv", "protobuf", "text"] }
```

## Compatibility matrix

| axum | axum-streams |
|------|--------------|
| 0.7 | v0.11 |
| 0.6 | v0.9-v0.10 |
| 0.5 | 0.7 |

Expand Down
17 changes: 8 additions & 9 deletions examples/csv-example.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use std::net::SocketAddr;

use futures::prelude::*;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio_stream::StreamExt;

use axum_streams::*;

#[derive(Debug, Clone, Deserialize, Serialize)]
struct MyTestStructure {
some_test_field: String,
some_test_field1: String,
some_test_field2: String,
}

fn source_test_stream() -> impl Stream<Item = MyTestStructure> {
// Simulating a stream with a plain vector and throttling to show how it works
stream::iter(vec![
MyTestStructure {
some_test_field: "test1".to_string()
some_test_field1: "test1".to_string(),
some_test_field2: "test2".to_string()
};
1000
])
.throttle(std::time::Duration::from_millis(50))
.throttle(std::time::Duration::from_millis(500))
}

async fn test_csv_stream() -> impl IntoResponse {
Expand All @@ -45,10 +47,7 @@ async fn main() {
// `GET /` goes to `root`
.route("/csv-stream", get(test_csv_stream));

let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
axum::serve(listener, app).await.unwrap();
}
9 changes: 3 additions & 6 deletions examples/json-example.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use std::net::SocketAddr;

use futures::prelude::*;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tokio_stream::StreamExt;

use axum_streams::*;
Expand Down Expand Up @@ -41,10 +41,7 @@ async fn main() {
.route("/json-array-stream", get(test_json_array_stream))
.route("/json-nl-stream", get(test_json_nl_stream));

let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
axum::serve(listener, app).await.unwrap();
}
10 changes: 3 additions & 7 deletions examples/protobuf-example.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use std::net::SocketAddr;

use futures::prelude::*;
use tokio::net::TcpListener;
use tokio_stream::StreamExt;

use axum_streams::*;
Expand Down Expand Up @@ -36,10 +36,6 @@ async fn main() {
// `GET /` goes to `root`
.route("/protobuf-stream", get(test_protobuf_stream));

let addr = SocketAddr::from(([127, 0, 0, 1], 8080));

axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();
axum::serve(listener, app).await.unwrap();
}
9 changes: 3 additions & 6 deletions examples/text-example.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use std::net::SocketAddr;

use futures::prelude::*;
use tokio::net::TcpListener;
use tokio_stream::StreamExt;

use axum_streams::*;
Expand All @@ -28,10 +28,7 @@ async fn main() {
// `GET /` goes to `root`
.route("/text-stream", get(test_text_stream));

let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
axum::serve(listener, app).await.unwrap();
}
8 changes: 5 additions & 3 deletions src/csv_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use serde::Serialize;

pub struct CsvStreamFormat {
Expand Down Expand Up @@ -96,7 +97,7 @@ where
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, T>,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
let stream_with_header = self.has_headers;
let stream_delimiter = self.delimiter;
let stream_flexible = self.flexible;
Expand All @@ -106,7 +107,7 @@ where
let stream_escape = self.escape;
let terminator = self.terminator;

let stream_bytes: BoxStream<Result<axum::body::Bytes, axum::Error>> = Box::pin({
let stream_bytes: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
stream.enumerate().map(move |(index, obj)| {
let mut writer = csv::WriterBuilder::new()
.has_headers(index == 0 && stream_with_header)
Expand All @@ -125,6 +126,7 @@ where
.into_inner()
.map_err(axum::Error::new)
.map(axum::body::Bytes::from)
.map(Frame::data)
})
});

Expand Down Expand Up @@ -188,7 +190,7 @@ mod tests {
}),
);

let client = TestClient::new(app);
let client = TestClient::new(app).await;

let expected_csv = test_stream_vec
.iter()
Expand Down
25 changes: 13 additions & 12 deletions src/json_formats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use serde::Serialize;
use std::io::Write;

Expand All @@ -22,8 +23,8 @@ where
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, T>,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
let stream_bytes: BoxStream<Result<axum::body::Bytes, axum::Error>> = Box::pin({
) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
let stream_bytes: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
stream.enumerate().map(|(index, obj)| {
let mut buf = BytesMut::new().writer();

Expand All @@ -37,7 +38,7 @@ where
match sep_write_res {
Ok(_) => {
match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) {
Ok(_) => Ok(buf.into_inner().freeze()),
Ok(_) => Ok(Frame::data(buf.into_inner().freeze())),
Err(e) => Err(e),
}
}
Expand All @@ -46,14 +47,14 @@ where
})
});

let prepend_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
let prepend_stream: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> =
Box::pin(futures_util::stream::once(futures_util::future::ready(
Ok::<_, axum::Error>(axum::body::Bytes::from(JSON_ARRAY_BEGIN_BYTES)),
Ok::<_, axum::Error>(Frame::data(axum::body::Bytes::from(JSON_ARRAY_BEGIN_BYTES))),
)));

let append_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
let append_stream: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> =
Box::pin(futures_util::stream::once(futures_util::future::ready(
Ok::<_, axum::Error>(axum::body::Bytes::from(JSON_ARRAY_END_BYTES)),
Ok::<_, axum::Error>(Frame::data(axum::body::Bytes::from(JSON_ARRAY_END_BYTES))),
)));

Box::pin(prepend_stream.chain(stream_bytes.chain(append_stream)))
Expand Down Expand Up @@ -84,13 +85,13 @@ where
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, T>,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
let stream_bytes: BoxStream<Result<axum::body::Bytes, axum::Error>> = Box::pin({
) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
let stream_bytes: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
stream.map(|obj| {
let mut buf = BytesMut::new().writer();
match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) {
Ok(_) => match buf.write_all(JSON_NL_SEP_BYTES).map_err(axum::Error::new) {
Ok(_) => Ok(buf.into_inner().freeze()),
Ok(_) => Ok(Frame::data(buf.into_inner().freeze())),
Err(e) => Err(e),
},
Err(e) => Err(e),
Expand Down Expand Up @@ -164,7 +165,7 @@ mod tests {
get(|| async { StreamBodyAs::new(JsonArrayStreamFormat::new(), test_stream) }),
);

let client = TestClient::new(app);
let client = TestClient::new(app).await;

let expected_json = serde_json::to_string(&test_stream_vec).unwrap();
let res = client.get("/").send().await.unwrap();
Expand Down Expand Up @@ -201,7 +202,7 @@ mod tests {
get(|| async { StreamBodyAs::new(JsonNewLineStreamFormat::new(), test_stream) }),
);

let client = TestClient::new(app);
let client = TestClient::new(app).await;

let expected_json = test_stream_vec
.iter()
Expand Down
11 changes: 7 additions & 4 deletions src/protobuf_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use http::HeaderMap;
use http_body::Frame;

pub struct ProtobufStreamFormat;

Expand All @@ -19,7 +20,7 @@ where
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, T>,
) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
fn write_protobuf_record<T>(obj: T) -> Result<Vec<u8>, axum::Error>
where
T: prost::Message,
Expand All @@ -33,10 +34,12 @@ where
Ok(frame_vec)
}

let stream_bytes: BoxStream<Result<axum::body::Bytes, axum::Error>> = Box::pin({
let stream_bytes: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
stream.map(move |obj| {
let write_protobuf_res = write_protobuf_record(obj);
write_protobuf_res.map(axum::body::Bytes::from)
write_protobuf_res
.map(axum::body::Bytes::from)
.map(Frame::data)
})
});

Expand Down Expand Up @@ -97,7 +100,7 @@ mod tests {
get(|| async { StreamBodyAs::new(ProtobufStreamFormat::new(), test_stream) }),
);

let client = TestClient::new(app);
let client = TestClient::new(app).await;

let expected_proto_buf: Vec<u8> = test_stream_vec
.iter()
Expand Down
18 changes: 6 additions & 12 deletions src/stream_body_as.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::stream_format::StreamingFormat;
use axum::body::HttpBody;
use axum::body::{Body, HttpBody};
use axum::response::{IntoResponse, Response};
use futures::Stream;
use futures_util::stream::BoxStream;
use http::HeaderMap;
use http_body::Frame;
use std::fmt::Formatter;
use std::pin::Pin;
use std::task::{Context, Poll};

pub struct StreamBodyAs<'a> {
stream: BoxStream<'a, Result<axum::body::Bytes, axum::Error>>,
stream: BoxStream<'a, Result<Frame<axum::body::Bytes>, axum::Error>>,
trailers: Option<HeaderMap>,
}

Expand Down Expand Up @@ -46,7 +47,7 @@ impl IntoResponse for StreamBodyAs<'static> {
HeaderMap::new()
};

let mut response = Response::new(axum::body::boxed(self));
let mut response: Response<Body> = Response::new(Body::new(self));
*response.headers_mut() = headers;
response
}
Expand All @@ -56,17 +57,10 @@ impl<'a> HttpBody for StreamBodyAs<'a> {
type Data = axum::body::Bytes;
type Error = axum::Error;

fn poll_data(
fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
Pin::new(&mut self.stream).poll_next(cx)
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(self.trailers.clone()))
}
}
Loading

0 comments on commit 7eb9f69

Please sign in to comment.