From f6370e51054c86748a8f30fc64075c638c8c06ef Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Wed, 20 Sep 2023 13:06:52 -0700
Subject: [PATCH] [FEAT] Native Recursive File Lister (#1353)

* Implements `iter_dir`, which gives a stream of the entries in a directory

* Also implements a recursive version that will explore subdirs
---
 Cargo.lock                                 |   1 +
 src/daft-core/Cargo.toml                   |   2 +-
 src/daft-core/src/array/ops/concat_agg.rs  |   6 +-
 src/daft-io/Cargo.toml                     |   1 +
 src/daft-io/src/object_io.rs               |  60 +++++++++-
 src/daft-io/src/python.rs                  |  23 +++-
 src/daft-io/src/s3_like.rs                 | 104 ++++++++++--------
 src/daft-plan/src/test/mod.rs              |   4 +-
 tests/integration/io/conftest.py           |  37 ++++---
 .../io/test_list_files_s3_minio.py         |  25 +++++
 10 files changed, 195 insertions(+), 68 deletions(-)
 create mode 100644 tests/integration/io/test_list_files_s3_minio.py

diff --git a/Cargo.lock b/Cargo.lock
index 8d8003f0f9..b12580ed71 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1063,6 +1063,7 @@ name = "daft-io"
 version = "0.1.10"
 dependencies = [
  "async-recursion",
+ "async-stream",
  "async-trait",
  "aws-config",
  "aws-credential-types",
diff --git a/src/daft-core/Cargo.toml b/src/daft-core/Cargo.toml
index edb39ed2eb..bf9018e28b 100644
--- a/src/daft-core/Cargo.toml
+++ b/src/daft-core/Cargo.toml
@@ -8,7 +8,7 @@ common-error = {path = "../common/error", default-features = false}
 dyn-clone = "1.0.12"
 fnv = "1.0.7"
 html-escape = {workspace = true}
-indexmap = {workspace = true, features = ["serde"], version = "2.0.0"}
+indexmap = {workspace = true, features = ["serde"]}
 lazy_static = {workspace = true}
 log = {workspace = true}
 ndarray = "0.15.6"
diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs
index 97b588fd0b..ec92b9ecdb 100644
--- a/src/daft-core/src/array/ops/concat_agg.rs
+++ b/src/daft-core/src/array/ops/concat_agg.rs
@@ -165,7 +165,7 @@ mod test {
             Field::new("foo", DataType::List(Box::new(DataType::Int64))),
             Int64Array::from((
                 "item",
-                Box::new(arrow2::array::Int64Array::from_iter(vec![].iter())),
+                Box::new(arrow2::array::Int64Array::from_iter([].iter())),
             ))
             .into_series(),
             arrow2::offset::OffsetsBuffer::<i64>::try_from(vec![0, 0, 0, 0])?,
@@ -190,7 +190,7 @@ mod test {
             Int64Array::from((
                 "item",
                 Box::new(arrow2::array::Int64Array::from_iter(
-                    vec![Some(0), Some(1), Some(1), Some(2), None, None, Some(10000)].iter(),
+                    [Some(0), Some(1), Some(1), Some(2), None, None, Some(10000)].iter(),
                 )),
             ))
             .into_series(),
@@ -225,7 +225,7 @@ mod test {
             Int64Array::from((
                 "item",
                 Box::new(arrow2::array::Int64Array::from_iter(
-                    vec![
+                    [
                         Some(0),
                         Some(0),
                         Some(0),
diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml
index ea361bc3e6..ec715f374b 100644
--- a/src/daft-io/Cargo.toml
+++ b/src/daft-io/Cargo.toml
@@ -1,5 +1,6 @@
 [dependencies]
 async-recursion = "1.0.4"
+async-stream = "0.3.5"
 async-trait = "0.1.71"
 aws-config = {version = "0.55.3", features = ["native-tls", "rt-tokio", "client-hyper", "credentials-sso"], default-features = false}
 aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]}
diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs
index 950438a37b..a3e5dbdd74 100644
--- a/src/daft-io/src/object_io.rs
+++ b/src/daft-io/src/object_io.rs
@@ -1,9 +1,11 @@
 use std::ops::Range;
+use std::sync::Arc;
 
 use async_trait::async_trait;
 use bytes::Bytes;
 use futures::stream::{BoxStream, Stream};
 use futures::StreamExt;
+use tokio::sync::mpsc::Sender;
 use tokio::sync::OwnedSemaphorePermit;
 
 use crate::local::{collect_file, LocalFile};
@@ -67,6 +69,8 @@ pub struct LSResult {
     pub continuation_token: Option<String>,
 }
 
+use async_stream::stream;
+
 #[async_trait]
 pub(crate) trait ObjectSource: Sync + Send {
     async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult>;
@@ -81,5 +85,59 @@ pub(crate) trait ObjectSource: Sync + Send {
         continuation_token: Option<&str>,
     ) -> super::Result<LSResult>;
 
-    // async fn iter_dir(&self, path: &str, limit: Option<usize>) -> super::Result<Arc<BoxStream<super::Result<FileMetadata>>>>;
+    async fn iter_dir(
+        &self,
+        uri: &str,
+        delimiter: Option<&str>,
+        _limit: Option<usize>,
+    ) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
+        let uri = uri.to_string();
+        let delimiter = delimiter.map(String::from);
+        let s = stream! {
+            let lsr = self.ls(&uri, delimiter.as_deref(), None).await?;
+            let mut continuation_token = lsr.continuation_token.clone();
+            for file in lsr.files {
+                yield Ok(file);
+            }
+
+            while continuation_token.is_some() {
+                let lsr = self.ls(&uri, delimiter.as_deref(), continuation_token.as_deref()).await?;
+                continuation_token = lsr.continuation_token.clone();
+                for file in lsr.files {
+                    yield Ok(file);
+                }
+            }
+        };
+        Ok(s.boxed())
+    }
+}
+
+pub(crate) async fn recursive_iter(
+    source: Arc<dyn ObjectSource>,
+    uri: &str,
+) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
+    let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024);
+    fn add_to_channel(source: Arc<dyn ObjectSource>, tx: Sender<FileMetadata>, dir: String) {
+        tokio::spawn(async move {
+            let mut s = source.iter_dir(&dir, None, None).await.unwrap();
+            let tx = &tx;
+            while let Some(tr) = s.next().await {
+                let tr = tr.unwrap();
+                match tr.filetype {
+                    FileType::File => tx.send(tr).await.unwrap(),
+                    FileType::Directory => add_to_channel(source.clone(), tx.clone(), tr.filepath),
+                };
+            }
+        });
+    }
+
+    add_to_channel(source, to_rtn_tx, uri.to_string());
+
+    let to_rtn_stream = stream! {
+        while let Some(v) = to_rtn_rx.recv().await {
+            yield Ok(v)
+        }
+    };
+
+    Ok(to_rtn_stream.boxed())
 }
diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs
index 306ecdc745..7e48ac6657 100644
--- a/src/daft-io/src/python.rs
+++ b/src/daft-io/src/python.rs
@@ -2,8 +2,9 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
 pub use py::register_modules;
 
 mod py {
-    use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
+    use crate::{get_io_client, get_runtime, object_io::recursive_iter, parse_url};
     use common_error::DaftResult;
+    use futures::TryStreamExt;
     use pyo3::{
         prelude::*,
         types::{PyDict, PyList},
@@ -13,10 +14,11 @@
     fn io_list(
         py: Python,
         path: String,
+        recursive: Option<bool>,
         multithreaded_io: Option<bool>,
        io_config: Option<IOConfig>,
     ) -> PyResult<&PyList> {
-        let lsr: DaftResult<LSResult> = py.allow_threads(|| {
+        let lsr: DaftResult<Vec<crate::object_io::FileMetadata>> = py.allow_threads(|| {
             let io_client = get_io_client(
                 multithreaded_io.unwrap_or(true),
                 io_config.unwrap_or_default().config.into(),
@@ -27,12 +29,25 @@
             runtime_handle.block_on(async move {
                 let source = io_client.get_source(&scheme).await?;
-                Ok(source.ls(&path, None, None).await?)
+                let files = if recursive.is_some_and(|r| r) {
+                    recursive_iter(source, &path)
+                        .await?
+                        .try_collect::<Vec<_>>()
+                        .await?
+                } else {
+                    source
+                        .iter_dir(&path, None, None)
+                        .await?
+                        .try_collect::<Vec<_>>()
+                        .await?
+                };
+
+                Ok(files)
             })
         });
         let lsr = lsr?;
         let mut to_rtn = vec![];
-        for file in lsr.files {
+        for file in lsr {
             let dict = PyDict::new(py);
             dict.set_item("type", format!("{:?}", file.filetype))?;
             dict.set_item("path", file.filepath)?;
diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs
index 14a09916d3..7d14fee157 100644
--- a/src/daft-io/src/s3_like.rs
+++ b/src/daft-io/src/s3_like.rs
@@ -134,6 +134,20 @@ impl From<Error> for super::Error {
                     source: err.into(),
                 },
             },
+            UnableToListObjects { path, source } => match source.into_service_error() {
+                ListObjectsV2Error::NoSuchBucket(no_such_key) => super::Error::NotFound {
+                    path,
+                    source: no_such_key.into(),
+                },
+                ListObjectsV2Error::Unhandled(v) => super::Error::Unhandled {
+                    path,
+                    msg: DisplayErrorContext(v).to_string(),
+                },
+                err => super::Error::UnableToOpenFile {
+                    path,
+                    source: err.into(),
+                },
+            },
             InvalidUrl { path, source } => super::Error::InvalidUrl { path, source },
             UnableToReadBytes { path, source } => super::Error::UnableToReadBytes {
                 path,
@@ -538,7 +552,7 @@ impl S3LikeSource {
             Ok(v) => {
                 let dirs = v.common_prefixes();
                 let files = v.contents();
-                let continuation_token = v.continuation_token().map(|s| s.to_string());
+                let continuation_token = v.next_continuation_token().map(|s| s.to_string());
                 let mut total_len = 0;
                 if let Some(dirs) = dirs {
                     total_len += dirs.len()
@@ -641,6 +655,8 @@ impl ObjectSource for S3LikeSource {
     ) -> super::Result<LSResult> {
         let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?;
         let delimiter = delimiter.unwrap_or("/");
+        log::warn!("{:?}", parsed);
+
         let bucket = match parsed.host_str() {
             Some(s) => Ok(s),
             None => Err(Error::InvalidUrl {
@@ -650,57 +666,57 @@
         }?;
         let key = parsed.path();
 
-        if let Some(key) = key.strip_prefix('/') {
-            // assume its a directory first
-            let key = format!("{}/", key.strip_suffix('/').unwrap_or(key));
-            let lsr = {
-                let permit = self
-                    .connection_pool_sema
-                    .acquire()
-                    .await
-                    .context(UnableToGrabSemaphoreSnafu)?;
-                self._list_impl(
+        let key = key
+            .strip_prefix('/')
+            .map(|k| k.strip_suffix('/').unwrap_or(k));
+        let key = key.unwrap_or("");
+
+        // assume its a directory first
+        let lsr = {
+            let permit = self
+                .connection_pool_sema
+                .acquire()
+                .await
+                .context(UnableToGrabSemaphoreSnafu)?;
+            self._list_impl(
+                permit,
+                bucket,
+                key,
+                delimiter.into(),
+                continuation_token.map(String::from),
+                &self.default_region,
+            )
+            .await?
+        };
+        if lsr.files.is_empty() && key.contains('/') {
+            let permit = self
+                .connection_pool_sema
+                .acquire()
+                .await
+                .context(UnableToGrabSemaphoreSnafu)?;
+            // Might be a File
+            let split = key.rsplit_once('/');
+            let (new_key, _) = split.unwrap();
+            let mut lsr = self
+                ._list_impl(
                     permit,
                     bucket,
-                    &key,
+                    new_key,
                     delimiter.into(),
                     continuation_token.map(String::from),
                     &self.default_region,
                 )
-                .await?
-            };
-            if lsr.files.is_empty() && key.contains('/') {
-                let permit = self
-                    .connection_pool_sema
-                    .acquire()
-                    .await
-                    .context(UnableToGrabSemaphoreSnafu)?;
-                // Might be a File
-                let split = key.rsplit_once('/');
-                let (new_key, _) = split.unwrap();
-                let mut lsr = self
-                    ._list_impl(
-                        permit,
-                        bucket,
-                        new_key,
-                        delimiter.into(),
-                        continuation_token.map(String::from),
-                        &self.default_region,
-                    )
-                    .await?;
-                let target_path = format!("s3://{bucket}/{new_key}");
-                lsr.files.retain(|f| f.filepath == target_path);
-
-                if lsr.files.is_empty() {
-                    // Isnt a file or a directory
-                    return Err(Error::NotFound { path: path.into() }.into());
-                }
-                Ok(lsr)
-            } else {
-                Ok(lsr)
+                .await?;
+            let target_path = format!("s3://{bucket}/{new_key}");
+            lsr.files.retain(|f| f.filepath == target_path);
+
+            if lsr.files.is_empty() {
+                // Isnt a file or a directory
+                return Err(Error::NotFound { path: path.into() }.into());
             }
+            Ok(lsr)
         } else {
-            Err(Error::NotAFile { path: path.into() }.into())
+            Ok(lsr)
         }
     }
 }
diff --git a/src/daft-plan/src/test/mod.rs b/src/daft-plan/src/test/mod.rs
index c08c12008c..1b7406faaf 100644
--- a/src/daft-plan/src/test/mod.rs
+++ b/src/daft-plan/src/test/mod.rs
@@ -12,7 +12,7 @@ use crate::{
 pub fn dummy_scan_node(fields: Vec<Field>) -> LogicalPlanBuilder {
     let schema = Arc::new(Schema::new(fields).unwrap());
     LogicalPlanBuilder::table_scan(
-        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]).into(),
+        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]),
         schema,
         FileFormatConfig::Json(JsonSourceConfig {}).into(),
         StorageConfig::Native(NativeStorageConfig::new_internal(None).into()).into(),
@@ -24,7 +24,7 @@ pub fn dummy_scan_node(fields: Vec<Field>) -> LogicalPlanBuilder {
 pub fn dummy_scan_node_with_limit(fields: Vec<Field>, limit: Option<usize>) -> LogicalPlanBuilder {
     let schema = Arc::new(Schema::new(fields).unwrap());
     LogicalPlanBuilder::table_scan_with_limit(
-        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]).into(),
+        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]),
         schema,
         FileFormatConfig::Json(JsonSourceConfig {}).into(),
         StorageConfig::Native(NativeStorageConfig::new_internal(None).into()).into(),
diff --git a/tests/integration/io/conftest.py b/tests/integration/io/conftest.py
index ca5a923b5b..8d7c947f3c 100644
--- a/tests/integration/io/conftest.py
+++ b/tests/integration/io/conftest.py
@@ -79,8 +79,8 @@ def retry_server_s3_config(request) -> daft.io.IOConfig:
 
 
 @contextlib.contextmanager
-def mount_data_minio(
-    minio_io_config: daft.io.IOConfig, folder: pathlib.Path, bucket_name: str = "my-minio-bucket"
+def minio_create_bucket(
+    minio_io_config: daft.io.IOConfig, bucket_name: str = "my-minio-bucket"
 ) -> YieldFixture[list[str]]:
     """Mounts data in `folder` into files in minio
 
@@ -92,22 +92,33 @@
         client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url},
     )
     fs.mkdir(bucket_name)
-
-    urls = []
-    for p in folder.glob("**/*"):
-        if not p.is_file():
-            continue
-        key = str(p.relative_to(folder))
-        url = f"s3://{bucket_name}/{key}"
-        fs.write_bytes(url, p.read_bytes())
-        urls.append(url)
-
     try:
-        yield urls
+        yield fs
     finally:
         fs.rm(bucket_name, recursive=True)
 
 
+@contextlib.contextmanager
+def mount_data_minio(
+    minio_io_config: daft.io.IOConfig, folder: pathlib.Path, bucket_name: str = "my-minio-bucket"
+) -> YieldFixture[list[str]]:
+    """Mounts data in `folder` into files in minio
+
+    Yields a list of S3 URLs
+    """
+    with minio_create_bucket(minio_io_config=minio_io_config, bucket_name=bucket_name) as fs:
+        urls = []
+        for p in folder.glob("**/*"):
+            if not p.is_file():
+                continue
+            key = str(p.relative_to(folder))
+            url = f"s3://{bucket_name}/{key}"
+            fs.write_bytes(url, p.read_bytes())
+            urls.append(url)
+
+        yield urls
+
+
 @contextlib.contextmanager
 def mount_data_nginx(nginx_config: tuple[str, pathlib.Path], folder: pathlib.Path) -> YieldFixture[list[str]]:
     """Mounts data in `folder` into servable static files in NGINX
diff --git a/tests/integration/io/test_list_files_s3_minio.py b/tests/integration/io/test_list_files_s3_minio.py
new file mode 100644
index 0000000000..2542996fb5
--- /dev/null
+++ b/tests/integration/io/test_list_files_s3_minio.py
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+import pytest
+
+from daft.daft import io_list
+
+from .conftest import minio_create_bucket
+
+
+def compare_s3_result(daft_ls_result: list, s3fs_result: list):
+    daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result]
+    s3fs_files = [(f"s3://{f['Key']}", f["type"]) for f in s3fs_result]
+    assert sorted(daft_files) == sorted(s3fs_files)
+
+
+@pytest.mark.integration()
+def test_flat_directory_listing(minio_io_config):
+    bucket_name = "bucket"
+    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
+        files = ["a", "b", "c"]
+        for name in files:
+            fs.touch(f"{bucket_name}/{name}")
+        daft_ls_result = io_list(f"s3://{bucket_name}", io_config=minio_io_config)
+        s3fs_result = fs.ls(f"s3://{bucket_name}", detail=True)
+        compare_s3_result(daft_ls_result, s3fs_result)
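
For reference, a recursive-listing check in the same style as test_flat_directory_listing could look like the sketch below, reusing the imports and the compare_s3_result helper from test_list_files_s3_minio.py above. It is illustrative only and not part of this patch: the test name, the nested key layout, and the use of s3fs find() to build the expected file list are assumptions; recursive=True is the new flag exposed by io_list in src/daft-io/src/python.rs.

# Illustrative sketch, not part of the commit.
@pytest.mark.integration()
def test_recursive_directory_listing(minio_io_config):
    bucket_name = "bucket"
    with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs:
        # Hypothetical nested layout: one top-level key plus two nested keys.
        files = ["a", "nested/b", "nested/deep/c"]
        for name in files:
            fs.touch(f"{bucket_name}/{name}")
        # recursive=True routes io_list through the new recursive_iter() path,
        # which yields files only (directories are recursed into, not emitted).
        daft_ls_result = io_list(f"s3://{bucket_name}", recursive=True, io_config=minio_io_config)
        # s3fs find() likewise returns object keys only, so the two listings should line up.
        s3fs_result = list(fs.find(f"s3://{bucket_name}", detail=True).values())
        compare_s3_result(daft_ls_result, s3fs_result)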