From a024123dd2f157d562ac73b9b337b40914b1921e Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 25 Sep 2023 15:55:13 -0700 Subject: [PATCH] [BUG] Fix scheme bug in GCS anonymous mode (#1443) I think the "deadlocking" in #1432 is actually caused by Python's difflib being really slow on the diff of the bad result. Also, because our io integration tests on that workflow isn't provided with GCP credentials, it is defaulting to the "anonymous" GCS client which is s3 based. We should fix that once we verify that this fix works. Closes: #1432 --------- Co-authored-by: Jay Chia --- src/daft-io/src/s3_like.rs | 20 ++++++++++++++++---- tests/integration/io/test_list_files_gcs.py | 20 +++++++++++++------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 2b4614e110..8dba518e46 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -541,10 +541,12 @@ impl S3LikeSource { } } + #[allow(clippy::too_many_arguments)] #[async_recursion] async fn _list_impl( &self, _permit: SemaphorePermit<'async_recursion>, + scheme: &str, bucket: &str, key: &str, delimiter: String, @@ -587,7 +589,7 @@ impl S3LikeSource { } else { request.send().await }; - let uri = &format!("s3://{bucket}/{key}"); + let uri = &format!("{scheme}://{bucket}/{key}"); match response { Ok(v) => { let dirs = v.common_prefixes(); @@ -604,7 +606,10 @@ impl S3LikeSource { if let Some(dirs) = dirs { for d in dirs { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", d.prefix().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + d.prefix().unwrap_or_default() + ), size: None, filetype: FileType::Directory, }; @@ -614,7 +619,10 @@ impl S3LikeSource { if let Some(files) = files { for f in files { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", f.key().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + f.key().unwrap_or_default() + ), size: Some(f.size() as u64), filetype: FileType::File, }; @@ -646,6 +654,7 @@ impl S3LikeSource { log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region); self._list_impl( _permit, + scheme, bucket, key, delimiter, @@ -694,6 +703,7 @@ impl ObjectSource for S3LikeSource { continuation_token: Option<&str>, ) -> super::Result { let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let scheme = parsed.scheme(); let delimiter = delimiter.unwrap_or("/"); let bucket = match parsed.host_str() { @@ -723,6 +733,7 @@ impl ObjectSource for S3LikeSource { self._list_impl( permit, + scheme, bucket, &key, delimiter.into(), @@ -742,6 +753,7 @@ impl ObjectSource for S3LikeSource { let mut lsr = self ._list_impl( permit, + scheme, bucket, key, delimiter.into(), @@ -749,7 +761,7 @@ impl ObjectSource for S3LikeSource { &self.default_region, ) .await?; - let target_path = format!("s3://{bucket}/{key}"); + let target_path = format!("{scheme}://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); if lsr.files.is_empty() { diff --git a/tests/integration/io/test_list_files_gcs.py b/tests/integration/io/test_list_files_gcs.py index 640d9d5be4..1053e25fa1 100644 --- a/tests/integration/io/test_list_files_gcs.py +++ b/tests/integration/io/test_list_files_gcs.py @@ -3,9 +3,11 @@ import gcsfs import pytest -from daft.daft import io_list +from daft.daft import GCSConfig, IOConfig, io_list BUCKET = "daft-public-data-gs" +DEFAULT_GCS_CONFIG = GCSConfig(project_id=None, anonymous=None) +ANON_GCS_CONFIG = GCSConfig(project_id=None, anonymous=True) def gcsfs_recursive_list(fs, path) -> list: @@ -49,28 +51,32 @@ def compare_gcs_result(daft_ls_result: list, fsspec_result: list): ], ) @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_flat_directory_listing(path, recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_flat_directory_listing(path, recursive, gcs_config): fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_single_file_listing(recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_single_file_listing(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/file.txt" fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() -def test_gs_notfound(): +@pytest.mark.parametrize("recursive", [False, True]) +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_notfound(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/MISSING" fs = gcsfs.GCSFileSystem() with pytest.raises(FileNotFoundError): fs.ls(path, detail=True) with pytest.raises(FileNotFoundError, match=path): - io_list(path) + io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config))