diff --git a/src/bao_slice_decoder.rs b/src/bao_slice_decoder.rs index 7df23f2..bb33e20 100644 --- a/src/bao_slice_decoder.rs +++ b/src/bao_slice_decoder.rs @@ -505,7 +505,7 @@ impl Read for SliceDecoder { } #[derive(Debug)] -pub struct AsyncSliceDecoder { +pub(crate) struct AsyncSliceDecoder { inner: SliceValidator, current_item: Option, } diff --git a/src/get.rs b/src/get.rs index b8f76d6..c5253ca 100644 --- a/src/get.rs +++ b/src/get.rs @@ -1,4 +1,5 @@ use std::fmt::Debug; +use std::io; use std::net::SocketAddr; use std::time::{Duration, Instant}; @@ -9,7 +10,7 @@ use postcard::experimental::max_size::MaxSize; use s2n_quic::stream::ReceiveStream; use s2n_quic::Connection; use s2n_quic::{client::Connect, Client}; -use tokio::io::AsyncRead; +use tokio::io::{AsyncRead, ReadBuf}; use tracing::debug; use crate::bao_slice_decoder::AsyncSliceDecoder; @@ -66,6 +67,37 @@ pub struct Stats { pub mbits: f64, } +/// A verified stream of data coming from the provider +/// +/// We guarantee that the data is correct by incrementally verifying a hash +#[repr(transparent)] +#[derive(Debug)] +pub struct DataStream(AsyncSliceDecoder); + +impl DataStream { + fn new(inner: ReceiveStream, hash: bao::Hash) -> Self { + DataStream(AsyncSliceDecoder::new(inner, hash, 0, u64::MAX)) + } + + async fn read_size(&mut self) -> io::Result { + self.0.read_size().await + } + + fn into_inner(self) -> ReceiveStream { + self.0.into_inner() + } +} + +impl AsyncRead for DataStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_read(cx, buf) + } +} + pub async fn run( hash: bao::Hash, token: AuthToken, @@ -79,8 +111,8 @@ where FutA: Future>, B: FnMut(Collection) -> FutB, FutB: Future>, - C: FnMut(bao::Hash, AsyncSliceDecoder, Option) -> FutC, - FutC: Future>>, + C: FnMut(bao::Hash, DataStream, Option) -> FutC, + FutC: Future>, { let now = Instant::now(); let (_client, mut connection) = setup(opts).await?; @@ -208,13 +240,11 @@ where /// /// Returns an `AsyncReader` /// The `AsyncReader` can be used to read the content. -async fn handle_blob_response< - R: AsyncRead + futures::io::AsyncRead + Send + Sync + Unpin + 'static, ->( +async fn handle_blob_response( hash: bao::Hash, - mut reader: R, + mut reader: ReceiveStream, buffer: &mut BytesMut, -) -> Result> { +) -> Result { match read_lp_data(&mut reader, buffer).await? { Some(response_buffer) => { let response: Response = postcard::from_bytes(&response_buffer)?; @@ -231,7 +261,7 @@ async fn handle_blob_response< // next blob in collection will be sent over Res::Found => { assert!(buffer.is_empty()); - let decoder = AsyncSliceDecoder::new(reader, hash, 0, u64::MAX); + let decoder = DataStream::new(reader, hash); Ok(decoder) } }