From 61bf3e27c3bfa04aa03a0a0ac3ed822560dd1af0 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Thu, 21 Sep 2023 18:13:59 -0700 Subject: [PATCH] [FEAT] Native S3 Lister, support trailing slashes and fix panics when connection is dropped for tokio (#1404) --- src/daft-io/src/lib.rs | 6 ++++++ src/daft-io/src/object_io.rs | 12 ++++++++---- src/daft-io/src/s3_like.rs | 18 +++++++++--------- .../integration/io/test_list_files_s3_minio.py | 18 ++++++++++++++++++ 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index d83a486e65..e83a9768a9 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -84,6 +84,12 @@ pub enum Error { #[snafu(display("Unhandled Error for path: {}\nDetails:\n{}", path, msg))] Unhandled { path: String, msg: String }, + #[snafu( + display("Error sending data over a tokio channel: {}", source), + context(false) + )] + UnableToSendDataOverChannel { source: DynError }, + #[snafu(display("Error joining spawned task: {}", source), context(false))] JoinError { source: tokio::task::JoinError }, } diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index d793c129c8..3e624cb27f 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -127,18 +127,22 @@ pub(crate) async fn recursive_iter( let mut s = match s { Ok(s) => s, Err(e) => { - tx.send(Err(e)).await.unwrap(); - return; + tx.send(Err(e)).await.map_err(|se| { + super::Error::UnableToSendDataOverChannel { source: se.into() } + })?; + return super::Result::<_, super::Error>::Ok(()); } }; let tx = &tx; while let Some(tr) = s.next().await { - let tr = tr; if let Ok(ref tr) = tr && matches!(tr.filetype, FileType::Directory) { add_to_channel(source.clone(), tx.clone(), tr.filepath.clone()) } - tx.send(tr).await.unwrap(); + tx.send(tr) + .await + .map_err(|e| super::Error::UnableToSendDataOverChannel { source: e.into() })?; } + super::Result::Ok(()) }); } diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index cc1c9c8c71..f9bf4ec16b 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -699,14 +699,15 @@ impl ObjectSource for S3LikeSource { }), }?; let key = parsed.path(); - - let key = key.strip_prefix('/').unwrap_or(""); + let key = key + .trim_start_matches(delimiter) + .trim_end_matches(delimiter); let key = if key.is_empty() { "".to_string() } else { - let key = key.strip_suffix('/').unwrap_or(key); - format!("{key}/") + format!("{key}{delimiter}") }; + // assume its a directory first let lsr = { let permit = self @@ -725,26 +726,25 @@ impl ObjectSource for S3LikeSource { ) .await? }; - if lsr.files.is_empty() && key.contains('/') { + if lsr.files.is_empty() && key.contains(delimiter) { let permit = self .connection_pool_sema .acquire() .await .context(UnableToGrabSemaphoreSnafu)?; // Might be a File - let split = key.rsplit_once('/'); - let (new_key, _) = split.unwrap(); + let key = key.trim_end_matches(delimiter); let mut lsr = self ._list_impl( permit, bucket, - new_key, + key, delimiter.into(), continuation_token.map(String::from), &self.default_region, ) .await?; - let target_path = format!("s3://{bucket}/{new_key}"); + let target_path = format!("s3://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); if lsr.files.is_empty() { diff --git a/tests/integration/io/test_list_files_s3_minio.py b/tests/integration/io/test_list_files_s3_minio.py index 7a1744c307..66aae58418 100644 --- a/tests/integration/io/test_list_files_s3_minio.py +++ b/tests/integration/io/test_list_files_s3_minio.py @@ -70,6 +70,24 @@ def test_single_file_directory_listing(minio_io_config, recursive): compare_s3_result(daft_ls_result, s3fs_result) +@pytest.mark.integration() +@pytest.mark.parametrize( + "recursive", + [False, True], +) +def test_single_file_directory_listing_trailing(minio_io_config, recursive): + bucket_name = "bucket" + with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs: + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + fs.write_bytes(f"s3://{bucket_name}/{name}", b"") + daft_ls_result = io_list(f"s3://{bucket_name}/c/cc///", io_config=minio_io_config, recursive=recursive) + fs.invalidate_cache() + s3fs_result = s3fs_recursive_list(fs, path=f"s3://{bucket_name}/c/cc///") + assert len(daft_ls_result) == 1 + compare_s3_result(daft_ls_result, s3fs_result) + + @pytest.mark.integration() @pytest.mark.parametrize( "recursive",