diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 60a79d03b2..d74d8a3541 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -5,8 +5,8 @@ use async_trait::async_trait; use bytes::Bytes; use futures::stream::{BoxStream, Stream}; use futures::StreamExt; -use tokio::sync::OwnedSemaphorePermit; use tokio::sync::mpsc::Sender; +use tokio::sync::OwnedSemaphorePermit; use tokio::sync::Semaphore; use crate::local::{collect_file, LocalFile}; @@ -86,7 +86,12 @@ pub(crate) trait ObjectSource: Sync + Send { continuation_token: Option<&str>, ) -> super::Result; - async fn iter_dir(&self, uri: &str, delimiter: Option<&str>, _limit: Option) -> super::Result>> { + async fn iter_dir( + &self, + uri: &str, + delimiter: Option<&str>, + _limit: Option, + ) -> super::Result>> { let uri = uri.to_string(); let delimiter = delimiter.map(String::from); let s = stream! { @@ -108,38 +113,26 @@ pub(crate) trait ObjectSource: Sync + Send { } } -pub(crate) async fn nested( +pub(crate) async fn recursive_iter( source: Arc, uri: &str, ) -> super::Result>> { let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024); - let sema = Arc::new(tokio::sync::Semaphore::new(64)); - fn add_to_channel( - source: Arc, - tx: Sender, - dir: String, - connection_counter: Arc, - ) { + fn add_to_channel(source: Arc, tx: Sender, dir: String) { tokio::spawn(async move { - let _handle = connection_counter.acquire().await.unwrap(); let mut s = source.iter_dir(&dir, None, None).await.unwrap(); let tx = &tx; while let Some(tr) = s.next().await { let tr = tr.unwrap(); match tr.filetype { FileType::File => tx.send(tr).await.unwrap(), - FileType::Directory => add_to_channel( - source.clone(), - tx.clone(), - tr.filepath, - connection_counter.clone(), - ), + FileType::Directory => add_to_channel(source.clone(), tx.clone(), tr.filepath), }; } }); } - add_to_channel(source, to_rtn_tx, uri.to_string(), sema); + add_to_channel(source, to_rtn_tx, uri.to_string()); let to_rtn_stream = stream! { while let Some(v) = to_rtn_rx.recv().await { diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index 306ecdc745..808836e6b2 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -2,8 +2,13 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig}; pub use py::register_modules; mod py { - use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url}; + use crate::{ + get_io_client, get_runtime, + object_io::{recursive_iter, LSResult}, + parse_url, + }; use common_error::DaftResult; + use futures::{StreamExt, TryStreamExt}; use pyo3::{ prelude::*, types::{PyDict, PyList}, @@ -13,10 +18,11 @@ mod py { fn io_list( py: Python, path: String, + recursive: bool, multithreaded_io: Option, io_config: Option, ) -> PyResult<&PyList> { - let lsr: DaftResult = py.allow_threads(|| { + let lsr: DaftResult> = py.allow_threads(|| { let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), @@ -27,12 +33,25 @@ mod py { runtime_handle.block_on(async move { let source = io_client.get_source(&scheme).await?; - Ok(source.ls(&path, None, None).await?) + let files = if recursive { + recursive_iter(source, &path) + .await? + .try_collect::>() + .await? + } else { + source + .iter_dir(&path, None, None) + .await? + .try_collect::>() + .await? + }; + + Ok(files) }) }); let lsr = lsr?; let mut to_rtn = vec![]; - for file in lsr.files { + for file in lsr { let dict = PyDict::new(py); dict.set_item("type", format!("{:?}", file.filetype))?; dict.set_item("path", file.filepath)?;