diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 2b4614e110..8dba518e46 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -541,10 +541,12 @@ impl S3LikeSource { } } + #[allow(clippy::too_many_arguments)] #[async_recursion] async fn _list_impl( &self, _permit: SemaphorePermit<'async_recursion>, + scheme: &str, bucket: &str, key: &str, delimiter: String, @@ -587,7 +589,7 @@ impl S3LikeSource { } else { request.send().await }; - let uri = &format!("s3://{bucket}/{key}"); + let uri = &format!("{scheme}://{bucket}/{key}"); match response { Ok(v) => { let dirs = v.common_prefixes(); @@ -604,7 +606,10 @@ impl S3LikeSource { if let Some(dirs) = dirs { for d in dirs { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", d.prefix().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + d.prefix().unwrap_or_default() + ), size: None, filetype: FileType::Directory, }; @@ -614,7 +619,10 @@ impl S3LikeSource { if let Some(files) = files { for f in files { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", f.key().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + f.key().unwrap_or_default() + ), size: Some(f.size() as u64), filetype: FileType::File, }; @@ -646,6 +654,7 @@ impl S3LikeSource { log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region); self._list_impl( _permit, + scheme, bucket, key, delimiter, @@ -694,6 +703,7 @@ impl ObjectSource for S3LikeSource { continuation_token: Option<&str>, ) -> super::Result { let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let scheme = parsed.scheme(); let delimiter = delimiter.unwrap_or("/"); let bucket = match parsed.host_str() { @@ -723,6 +733,7 @@ impl ObjectSource for S3LikeSource { self._list_impl( permit, + scheme, bucket, &key, delimiter.into(), @@ -742,6 +753,7 @@ impl ObjectSource for S3LikeSource { let mut lsr = self ._list_impl( permit, + scheme, bucket, key, delimiter.into(), @@ -749,7 +761,7 @@ impl ObjectSource for S3LikeSource { &self.default_region, ) .await?; - let target_path = format!("s3://{bucket}/{key}"); + let target_path = format!("{scheme}://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); if lsr.files.is_empty() { diff --git a/tests/integration/io/test_list_files_gcs.py b/tests/integration/io/test_list_files_gcs.py index 640d9d5be4..1053e25fa1 100644 --- a/tests/integration/io/test_list_files_gcs.py +++ b/tests/integration/io/test_list_files_gcs.py @@ -3,9 +3,11 @@ import gcsfs import pytest -from daft.daft import io_list +from daft.daft import GCSConfig, IOConfig, io_list BUCKET = "daft-public-data-gs" +DEFAULT_GCS_CONFIG = GCSConfig(project_id=None, anonymous=None) +ANON_GCS_CONFIG = GCSConfig(project_id=None, anonymous=True) def gcsfs_recursive_list(fs, path) -> list: @@ -49,28 +51,32 @@ def compare_gcs_result(daft_ls_result: list, fsspec_result: list): ], ) @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_flat_directory_listing(path, recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_flat_directory_listing(path, recursive, gcs_config): fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_single_file_listing(recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_single_file_listing(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/file.txt" fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() -def test_gs_notfound(): +@pytest.mark.parametrize("recursive", [False, True]) +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_notfound(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/MISSING" fs = gcsfs.GCSFileSystem() with pytest.raises(FileNotFoundError): fs.ls(path, detail=True) with pytest.raises(FileNotFoundError, match=path): - io_list(path) + io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config))