diff --git a/src/daft-csv/src/metadata.rs b/src/daft-csv/src/metadata.rs
index f45f25946b..1350bbe425 100644
--- a/src/daft-csv/src/metadata.rs
+++ b/src/daft-csv/src/metadata.rs
@@ -130,7 +130,7 @@ pub(crate) async fn read_csv_schema_single(
             )
             .await
         }
-        GetResult::Stream(stream, size, _) => {
+        GetResult::Stream(stream, size, ..) => {
             read_csv_schema_from_compressed_reader(
                 StreamReader::new(stream),
                 compression_codec,
diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs
index f03ebb883b..c646eed2ac 100644
--- a/src/daft-csv/src/read.rs
+++ b/src/daft-csv/src/read.rs
@@ -356,7 +356,7 @@ async fn read_csv_single_into_stream(
                     .unwrap_or(64 * 1024),
             )
         }
-        GetResult::Stream(stream, _, _) => (
+        GetResult::Stream(stream, ..) => (
             Box::new(StreamReader::new(stream)),
             read_options
                 .as_ref()
diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs
index 17c054899a..9fef7435db 100644
--- a/src/daft-io/src/azure_blob.rs
+++ b/src/daft-io/src/azure_blob.rs
@@ -484,6 +484,7 @@ impl ObjectSource for AzureBlobSource {
             io_stats_on_bytestream(Box::pin(stream), io_stats),
             None,
             None,
+            None,
         ))
     }
diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs
index f2ad4e67e4..4c39c144a9 100644
--- a/src/daft-io/src/google_cloud.rs
+++ b/src/daft-io/src/google_cloud.rs
@@ -174,6 +174,7 @@ impl GCSClientWrapper {
             io_stats_on_bytestream(response, io_stats),
             size,
             None,
+            None,
         ))
     }
diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs
index 9ed0381b2b..f455050ede 100644
--- a/src/daft-io/src/http.rs
+++ b/src/daft-io/src/http.rs
@@ -214,6 +214,7 @@ impl ObjectSource for HttpSource {
             io_stats_on_bytestream(stream, io_stats),
             size_bytes,
             None,
+            None,
         ))
     }
diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index 4407e5dfe6..087f280888 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -1,6 +1,7 @@
 #![feature(async_closure)]
 #![feature(let_chains)]
 #![feature(io_error_more)]
+#![feature(if_let_guard)]
 mod azure_blob;
 mod google_cloud;
 mod http;
@@ -19,6 +20,7 @@ pub mod python;
 pub use common_io_config::{AzureConfig, IOConfig, S3Config};
 pub use object_io::FileMetadata;
 pub use object_io::GetResult;
+use object_io::StreamingRetryParams;
 #[cfg(feature = "python")]
 pub use python::register_modules;
 pub use stats::{IOStatsContext, IOStatsRef};
@@ -232,7 +234,10 @@ impl IOClient {
     ) -> Result<GetResult> {
         let (scheme, path) = parse_url(&input)?;
         let source = self.get_source(&scheme).await?;
-        source.get(path.as_ref(), range, io_stats).await
+        let get_result = source
+            .get(path.as_ref(), range.clone(), io_stats.clone())
+            .await?;
+        Ok(get_result.with_retry(StreamingRetryParams::new(source, input, range, io_stats)))
     }
 
     pub async fn single_url_get_size(
diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs
index 8a9d316591..5d50c2126b 100644
--- a/src/daft-io/src/object_io.rs
+++ b/src/daft-io/src/object_io.rs
@@ -1,5 +1,6 @@
 use std::ops::Range;
 use std::sync::Arc;
+use std::time::Duration;
 
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -12,16 +13,44 @@ use tokio::sync::OwnedSemaphorePermit;
 use crate::local::{collect_file, LocalFile};
 use crate::stats::IOStatsRef;
 
+pub struct StreamingRetryParams {
+    source: Arc<dyn ObjectSource>,
+    input: String,
+    range: Option<Range<usize>>,
+    io_stats: Option<IOStatsRef>,
+}
+
+impl StreamingRetryParams {
+    pub(crate) fn new(
+        source: Arc<dyn ObjectSource>,
+        input: String,
+        range: Option<Range<usize>>,
+        io_stats: Option<IOStatsRef>,
+    ) -> Self {
+        Self {
+            source,
+            input,
+            range,
+            io_stats,
+        }
+    }
+}
+
 pub enum GetResult {
     File(LocalFile),
     Stream(
         BoxStream<'static, super::Result<Bytes>>,
         Option<usize>,
         Option<OwnedSemaphorePermit>,
+        Option<Box<StreamingRetryParams>>,
     ),
 }
 
-async fn collect_bytes<S>(mut stream: S, size_hint: Option<usize>) -> super::Result<Bytes>
+async fn collect_bytes<S>(
+    mut stream: S,
+    size_hint: Option<usize>,
+    _permit: Option<OwnedSemaphorePermit>,
+) -> super::Result<Bytes>
 where
     S: Stream<Item = super::Result<Bytes>> + Send + Unpin,
 {
@@ -47,9 +76,57 @@ where
 impl GetResult {
     pub async fn bytes(self) -> super::Result<Bytes> {
         use GetResult::*;
-        match self {
+        let mut get_result = self;
+        match get_result {
             File(f) => collect_file(f).await,
-            Stream(stream, size, _permit) => collect_bytes(stream, size).await,
+            Stream(stream, size, permit, retry_params) => {
+                use rand::Rng;
+                const NUM_TRIES: u64 = 3;
+                const JITTER_MS: u64 = 2_500;
+                const MAX_BACKOFF_MS: u64 = 20_000;
+
+                let mut result = collect_bytes(stream, size, permit).await; // drop permit to ensure quota
+                for attempt in 1..NUM_TRIES {
+                    match result {
+                        Err(super::Error::SocketError { .. })
+                        | Err(super::Error::UnableToReadBytes { .. })
+                            if let Some(rp) = &retry_params =>
+                        {
+                            let jitter = rand::thread_rng()
+                                .gen_range(0..((1 << (attempt - 1)) * JITTER_MS))
+                                as u64;
+                            let jitter = jitter.min(MAX_BACKOFF_MS);
+
+                            log::warn!(
+                                "Received Socket Error when streaming bytes! Attempt {attempt} out of {NUM_TRIES} tries. Trying again in {jitter}ms\nDetails\n{}",
+                                result.err().unwrap()
+                            );
+                            tokio::time::sleep(Duration::from_millis(jitter)).await;
+
+                            get_result = rp
+                                .source
+                                .get(&rp.input, rp.range.clone(), rp.io_stats.clone())
+                                .await?;
+                            if let GetResult::Stream(stream, size, permit, _) = get_result {
+                                result = collect_bytes(stream, size, permit).await;
+                            } else {
+                                unreachable!("Retrying a stream should always be a stream");
+                            }
+                        }
+                        _ => break,
+                    }
+                }
+                result
+            }
         }
     }
+
+    pub fn with_retry(self, params: StreamingRetryParams) -> Self {
+        match self {
+            GetResult::File(..) => self,
+            GetResult::Stream(s, size, permit, _) => {
+                GetResult::Stream(s, size, permit, Some(Box::new(params)))
+            }
+        }
+    }
 }
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index dbde82bded..f3a527e328 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -647,6 +647,7 @@ impl S3LikeSource {
                 stream,
                 Some(v.content_length as usize),
                 Some(permit),
+                None,
             ))
         }
@@ -933,7 +934,7 @@ impl ObjectSource for S3LikeSource {
             .await?;
 
         if io_stats.is_some() {
-            if let GetResult::Stream(stream, num_bytes, permit) = get_result {
+            if let GetResult::Stream(stream, num_bytes, permit, retry_params) = get_result {
                 if let Some(is) = io_stats.as_ref() {
                     is.mark_get_requests(1)
                 }
@@ -941,6 +942,7 @@
                     io_stats_on_bytestream(stream, io_stats),
                     num_bytes,
                     permit,
+                    retry_params,
                 ))
             } else {
                 panic!("This should always be a stream");
diff --git a/src/daft-json/src/read.rs b/src/daft-json/src/read.rs
index 94a2524bcc..b6ee963781 100644
--- a/src/daft-json/src/read.rs
+++ b/src/daft-json/src/read.rs
@@ -335,7 +335,7 @@ async fn read_json_single_into_stream(
                     .unwrap_or(64),
             )
         }
-        GetResult::Stream(stream, _, _) => (
+        GetResult::Stream(stream, ..) => (
             Box::new(StreamReader::new(stream)),
             // Use user-provided buffer size, falling back to 8 * the user-provided chunk size if that exists, otherwise falling back to 512 KiB as the default.
             read_options
diff --git a/src/daft-json/src/schema.rs b/src/daft-json/src/schema.rs
index be1ee3cfde..cb10875ebe 100644
--- a/src/daft-json/src/schema.rs
+++ b/src/daft-json/src/schema.rs
@@ -122,7 +122,7 @@ pub(crate) async fn read_json_schema_single(
             Box::new(BufReader::new(File::open(file.path).await?)),
             max_bytes,
         ),
-        GetResult::Stream(stream, size, _) => (
+        GetResult::Stream(stream, size, ..) => (
             Box::new(StreamReader::new(stream)),
             // Truncate max_bytes to size if both are set.
             max_bytes.map(|m| size.map(|s| m.min(s)).unwrap_or(m)),
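
Note on the retry behavior added in `src/daft-io/src/object_io.rs`: the loop in `GetResult::bytes` retries only on `SocketError` and `UnableToReadBytes`, and only when `StreamingRetryParams` were attached via `with_retry`; each retry re-issues the full `get` for the original input and range rather than resuming mid-stream, and local `File` results bypass the retry machinery entirely. The sleep between attempts is capped exponential backoff with full jitter. A minimal standalone sketch of that schedule (assuming the `rand` 0.8 crate; `backoff_ms` is an illustrative helper name, not part of the patch):

```rust
// Standalone sketch (not part of the patch) of the backoff schedule used by
// the retry loop in `GetResult::bytes`. Assumes the `rand` crate (0.8 API);
// `backoff_ms` is a hypothetical helper, named here for illustration only.
use rand::Rng;

const NUM_TRIES: u64 = 3; // total attempts, matching the patch
const JITTER_MS: u64 = 2_500; // base jitter window
const MAX_BACKOFF_MS: u64 = 20_000; // cap on any single sleep

/// Retry `attempt` (1-based) sleeps for a uniformly random duration drawn
/// from [0, 2^(attempt - 1) * JITTER_MS), clamped to MAX_BACKOFF_MS.
fn backoff_ms(attempt: u64) -> u64 {
    let upper = (1u64 << (attempt - 1)) * JITTER_MS;
    rand::thread_rng().gen_range(0..upper).min(MAX_BACKOFF_MS)
}

fn main() {
    // With NUM_TRIES = 3 there are at most two retries after the initial read.
    for attempt in 1..NUM_TRIES {
        println!("retry {attempt}: sleeping {} ms", backoff_ms(attempt));
    }
}
```

With these constants, retry 1 draws from [0, 2500) ms and retry 2 from [0, 5000) ms, so the 20 s cap only comes into play if `NUM_TRIES` is raised.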