From 8c20f39347e1e6cccf25c9cfdaadc81adf3a9501 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Thu, 21 Sep 2023 17:40:02 -0700 Subject: [PATCH] [FEAT] Native Rust listing of GCS (#1392) Implements `ls` functionality for GCS --------- Co-authored-by: Jay Chia --- .github/workflows/python-package.yml | 8 + src/daft-io/src/google_cloud.rs | 215 ++++++++++++++------ tests/integration/io/test_list_files_gcs.py | 83 ++++++++ 3 files changed, 239 insertions(+), 67 deletions(-) create mode 100644 tests/integration/io/test_list_files_gcs.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ea5066e35b..979ae9a1a5 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -258,6 +258,14 @@ jobs: aws-region: us-west-2 role-to-assume: ${{ secrets.ACTIONS_AWS_ROLE_ARN }} role-session-name: DaftPythonPackageGitHubWorkflow + - name: Assume GitHub Actions GCloud Credentials + uses: google-github-actions/auth@v1 + with: + credentials_json: ${{ secrets.ACTIONS_GCP_SERVICE_ACCOUNT_JSON }} + # NOTE: Workload Identity seems to be having problems with our Rust crate, so we use JSON instead + # See issue: https://github.com/yoshidan/google-cloud-rust/issues/171#issuecomment-1730511655 + # workload_identity_provider: ${{ secrets.ACTIONS_GCP_WORKLOAD_IDENTITY_PROVIDER }} + # service_account: ${{ secrets.ACTIONS_GCP_SERVICE_ACCOUNT }} - name: Spin up IO services uses: isbang/compose-action@v1.5.1 with: diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index 04a6f45d2a..fe8973d1e9 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -9,11 +9,14 @@ use async_trait::async_trait; use google_cloud_storage::client::Client; use google_cloud_storage::http::objects::get::GetObjectRequest; +use google_cloud_storage::http::objects::list::ListObjectsRequest; use google_cloud_storage::http::Error as GError; use snafu::IntoError; use snafu::ResultExt; use snafu::Snafu; +use crate::object_io::FileMetadata; +use crate::object_io::FileType; use crate::object_io::LSResult; use crate::object_io::ObjectSource; use crate::s3_like; @@ -25,6 +28,9 @@ enum Error { #[snafu(display("Unable to open {}: {}", path, source))] UnableToOpenFile { path: String, source: GError }, + #[snafu(display("Unable to list objects: \"{}\"", path))] + UnableToListObjects { path: String, source: GError }, + #[snafu(display("Unable to read data from {}: {}", path, source))] UnableToReadBytes { path: String, source: GError }, @@ -46,44 +52,44 @@ impl From for super::Error { fn from(error: Error) -> Self { use Error::*; match error { - UnableToReadBytes { path, source } | UnableToOpenFile { path, source } => { - match source { - GError::HttpClient(err) => match err.status().map(|s| s.as_u16()) { - Some(404) | Some(410) => super::Error::NotFound { - path, - source: err.into(), - }, - Some(401) => super::Error::Unauthorized { - store: super::SourceType::GCS, - path, - source: err.into(), - }, - _ => super::Error::UnableToOpenFile { - path, - source: err.into(), - }, + UnableToReadBytes { path, source } + | UnableToOpenFile { path, source } + | UnableToListObjects { path, source } => match source { + GError::HttpClient(err) => match err.status().map(|s| s.as_u16()) { + Some(404) | Some(410) => super::Error::NotFound { + path, + source: err.into(), + }, + Some(401) => super::Error::Unauthorized { + store: super::SourceType::GCS, + path, + source: err.into(), }, - GError::Response(err) => match err.code { - 404 | 410 => super::Error::NotFound { - path, - source: err.into(), - }, - 401 => super::Error::Unauthorized { - store: super::SourceType::GCS, - path, - source: err.into(), - }, - _ => super::Error::UnableToOpenFile { - path, - source: err.into(), - }, + _ => super::Error::UnableToOpenFile { + path, + source: err.into(), }, - GError::TokenSource(err) => super::Error::UnableToLoadCredentials { + }, + GError::Response(err) => match err.code { + 404 | 410 => super::Error::NotFound { + path, + source: err.into(), + }, + 401 => super::Error::Unauthorized { store: super::SourceType::GCS, - source: err, + path, + source: err.into(), }, - } - } + _ => super::Error::UnableToOpenFile { + path, + source: err.into(), + }, + }, + GError::TokenSource(err) => super::Error::UnableToLoadCredentials { + store: super::SourceType::GCS, + source: err, + }, + }, NotAFile { path } => super::Error::NotAFile { path }, InvalidUrl { path, source } => super::Error::InvalidUrl { path, source }, UnableToLoadCredentials { source } => super::Error::UnableToLoadCredentials { @@ -99,23 +105,23 @@ enum GCSClientWrapper { S3Compat(Arc), } +fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> { + let bucket = match uri.host_str() { + Some(s) => Ok(s), + None => Err(Error::InvalidUrl { + path: uri.to_string(), + source: url::ParseError::EmptyHost, + }), + }?; + let key = uri.path(); + let key = key.strip_prefix('/').unwrap_or(key); + Ok((bucket, key)) +} + impl GCSClientWrapper { async fn get(&self, uri: &str, range: Option>) -> super::Result { - let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; - let bucket = match parsed.host_str() { - Some(s) => Ok(s), - None => Err(Error::InvalidUrl { - path: uri.into(), - source: url::ParseError::EmptyHost, - }), - }?; - let key = parsed.path(); - let key = if let Some(key) = key.strip_prefix('/') { - key - } else { - return Err(Error::NotAFile { path: uri.into() }.into()); - }; - + let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; + let (bucket, key) = parse_uri(&uri)?; match self { GCSClientWrapper::Native(client) => { let req = GetObjectRequest { @@ -156,20 +162,8 @@ impl GCSClientWrapper { } async fn get_size(&self, uri: &str) -> super::Result { - let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; - let bucket = match parsed.host_str() { - Some(s) => Ok(s), - None => Err(Error::InvalidUrl { - path: uri.into(), - source: url::ParseError::EmptyHost, - }), - }?; - let key = parsed.path(); - let key = if let Some(key) = key.strip_prefix('/') { - key - } else { - return Err(Error::NotAFile { path: uri.into() }.into()); - }; + let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; + let (bucket, key) = parse_uri(&uri)?; match self { GCSClientWrapper::Native(client) => { let req = GetObjectRequest { @@ -192,6 +186,93 @@ impl GCSClientWrapper { } } } + + async fn _ls_impl( + &self, + client: &Client, + bucket: &str, + key: &str, + delimiter: Option<&str>, + continuation_token: Option<&str>, + ) -> super::Result { + let req = ListObjectsRequest { + bucket: bucket.to_string(), + prefix: Some(key.to_string()), + end_offset: None, + start_offset: None, + page_token: continuation_token.map(|s| s.to_string()), + delimiter: Some(delimiter.unwrap_or("/").to_string()), // returns results in "directory mode" + max_results: Some(1000), // Recommended value from API docs + include_trailing_delimiter: Some(false), // This will not populate "directories" in the response's .item[] + projection: None, + versions: None, + }; + let ls_response = client + .list_objects(&req) + .await + .context(UnableToListObjectsSnafu { + path: format!("gs://{}/{}", bucket, key), + })?; + let response_items = ls_response.items.unwrap_or_default(); + let response_prefixes = ls_response.prefixes.unwrap_or_default(); + let files = response_items.iter().filter_map(|obj| { + if obj.name.ends_with('/') { + // Sometimes the GCS API returns "folders" in .items[], so we manually filter here + None + } else { + Some(FileMetadata { + filepath: format!("gs://{}/{}", bucket, obj.name), + size: Some(obj.size as u64), + filetype: FileType::File, + }) + } + }); + let dirs = response_prefixes.iter().map(|pref| FileMetadata { + filepath: format!("gs://{}/{}", bucket, pref), + size: None, + filetype: FileType::Directory, + }); + Ok(LSResult { + files: files.chain(dirs).collect(), + continuation_token: ls_response.next_page_token, + }) + } + + async fn ls( + &self, + path: &str, + delimiter: Option<&str>, + continuation_token: Option<&str>, + ) -> super::Result { + let uri = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let (bucket, key) = parse_uri(&uri)?; + match self { + GCSClientWrapper::Native(client) => { + // Attempt to forcefully ls the key as a directory (by ensuring a "/" suffix) + // If no items were obtained, then this is actually a file and we perform a second ls to obtain just the file's + // details as the one-and-only-one entry + let forced_directory_key = format!("{}/", key.strip_suffix('/').unwrap_or(key)); + let forced_directory_ls_result = self + ._ls_impl( + client, + bucket, + forced_directory_key.as_str(), + delimiter, + continuation_token, + ) + .await?; + if forced_directory_ls_result.files.is_empty() { + self._ls_impl(client, bucket, key, delimiter, continuation_token) + .await + } else { + Ok(forced_directory_ls_result) + } + } + GCSClientWrapper::S3Compat(client) => { + client.ls(path, delimiter, continuation_token).await + } + } + } } pub(crate) struct GCSSource { @@ -248,10 +329,10 @@ impl ObjectSource for GCSSource { async fn ls( &self, - _path: &str, - _delimiter: Option<&str>, - _continuation_token: Option<&str>, + path: &str, + delimiter: Option<&str>, + continuation_token: Option<&str>, ) -> super::Result { - unimplemented!("gcs ls"); + self.client.ls(path, delimiter, continuation_token).await } } diff --git a/tests/integration/io/test_list_files_gcs.py b/tests/integration/io/test_list_files_gcs.py new file mode 100644 index 0000000000..ddbb29bc2d --- /dev/null +++ b/tests/integration/io/test_list_files_gcs.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import gcsfs +import pytest + +from daft.daft import io_list + +BUCKET = "daft-public-data-gs" + + +def compare_gcs_result(daft_ls_result: list, fsspec_result: list): + daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] + gcsfs_files = [(f"gs://{f['name']}", f["type"]) for f in fsspec_result] + + # Perform necessary post-processing of fsspec results to match expected behavior from Daft: + + # NOTE: gcsfs sometimes does not return the trailing / for directories, so we have to ensure it + gcsfs_files = [ + (f"{path.rstrip('/')}/", type_) if type_ == "directory" else (path, type_) for path, type_ in gcsfs_files + ] + + # NOTE: gcsfs will sometimes return 0-sized marker files for manually-created folders, which we ignore here + # Be careful here because this will end up pruning any truly size-0 files that are actually files and not folders! + size_0_files = {f"gs://{f['name']}" for f in fsspec_result if f["size"] == 0 and f["type"] == "file"} + gcsfs_files = [(path, type_) for path, type_ in gcsfs_files if path not in size_0_files] + + assert len(daft_files) == len(gcsfs_files) + assert sorted(daft_files) == sorted(gcsfs_files) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"gs://{BUCKET}", + f"gs://{BUCKET}/", + f"gs://{BUCKET}/test_ls", + f"gs://{BUCKET}/test_ls/", + f"gs://{BUCKET}/test_ls/paginated-1100-files/", + ], +) +def test_gs_flat_directory_listing(path): + fs = gcsfs.GCSFileSystem() + daft_ls_result = io_list(path) + fsspec_result = fs.ls(path, detail=True) + compare_gcs_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_gs_single_file_listing(): + path = f"gs://{BUCKET}/test_ls/file.txt" + fs = gcsfs.GCSFileSystem() + daft_ls_result = io_list(path) + fsspec_result = fs.ls(path, detail=True) + compare_gcs_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_gs_notfound(): + path = f"gs://{BUCKET}/test_ls/MISSING" + fs = gcsfs.GCSFileSystem() + with pytest.raises(FileNotFoundError): + fs.ls(path, detail=True) + + # NOTE: Google Cloud does not return a 404 to indicate anything missing, but just returns empty results + # Thus Daft is unable to differentiate between "missing" folders and "empty" folders + daft_ls_result = io_list(path) + assert daft_ls_result == [] + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"gs://{BUCKET}/test_ls", + f"gs://{BUCKET}/test_ls/", + ], +) +def test_gs_flat_directory_listing_recursive(path): + fs = gcsfs.GCSFileSystem() + daft_ls_result = io_list(path, recursive=True) + fsspec_result = list(fs.glob(path.rstrip("/") + "/**", detail=True).values()) + compare_gcs_result(daft_ls_result, fsspec_result)