Skip to content

Commit

Permalink
add binding for recursive iter to python
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Sep 12, 2023
1 parent 51f7ced commit 4f9be30
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
29 changes: 11 additions & 18 deletions src/daft-io/src/object_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use async_trait::async_trait;
use bytes::Bytes;
use futures::stream::{BoxStream, Stream};
use futures::StreamExt;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::mpsc::Sender;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::Semaphore;

use crate::local::{collect_file, LocalFile};
Expand Down Expand Up @@ -86,7 +86,12 @@ pub(crate) trait ObjectSource: Sync + Send {
continuation_token: Option<&str>,
) -> super::Result<LSResult>;

async fn iter_dir(&self, uri: &str, delimiter: Option<&str>, _limit: Option<usize>) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
async fn iter_dir(
&self,
uri: &str,
delimiter: Option<&str>,
_limit: Option<usize>,
) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
let uri = uri.to_string();
let delimiter = delimiter.map(String::from);
let s = stream! {
Expand All @@ -108,38 +113,26 @@ pub(crate) trait ObjectSource: Sync + Send {
}
}

pub(crate) async fn nested(
pub(crate) async fn recursive_iter(
source: Arc<dyn ObjectSource>,
uri: &str,
) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024);
let sema = Arc::new(tokio::sync::Semaphore::new(64));
fn add_to_channel(
source: Arc<dyn ObjectSource>,
tx: Sender<FileMetadata>,
dir: String,
connection_counter: Arc<Semaphore>,
) {
fn add_to_channel(source: Arc<dyn ObjectSource>, tx: Sender<FileMetadata>, dir: String) {
tokio::spawn(async move {
let _handle = connection_counter.acquire().await.unwrap();
let mut s = source.iter_dir(&dir, None, None).await.unwrap();
let tx = &tx;
while let Some(tr) = s.next().await {
let tr = tr.unwrap();
match tr.filetype {
FileType::File => tx.send(tr).await.unwrap(),
FileType::Directory => add_to_channel(
source.clone(),
tx.clone(),
tr.filepath,
connection_counter.clone(),
),
FileType::Directory => add_to_channel(source.clone(), tx.clone(), tr.filepath),
};
}
});
}

add_to_channel(source, to_rtn_tx, uri.to_string(), sema);
add_to_channel(source, to_rtn_tx, uri.to_string());

let to_rtn_stream = stream! {
while let Some(v) = to_rtn_rx.recv().await {
Expand Down
27 changes: 23 additions & 4 deletions src/daft-io/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
pub use py::register_modules;

mod py {
use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
use crate::{
get_io_client, get_runtime,
object_io::{recursive_iter, LSResult},
parse_url,
};
use common_error::DaftResult;
use futures::{StreamExt, TryStreamExt};
use pyo3::{
prelude::*,
types::{PyDict, PyList},
Expand All @@ -13,10 +18,11 @@ mod py {
fn io_list(
py: Python,
path: String,
recursive: bool,
multithreaded_io: Option<bool>,
io_config: Option<common_io_config::python::IOConfig>,
) -> PyResult<&PyList> {
let lsr: DaftResult<LSResult> = py.allow_threads(|| {
let lsr: DaftResult<Vec<_>> = py.allow_threads(|| {
let io_client = get_io_client(
multithreaded_io.unwrap_or(true),
io_config.unwrap_or_default().config.into(),
Expand All @@ -27,12 +33,25 @@ mod py {

runtime_handle.block_on(async move {
let source = io_client.get_source(&scheme).await?;
Ok(source.ls(&path, None, None).await?)
let files = if recursive {
recursive_iter(source, &path)
.await?
.try_collect::<Vec<_>>()
.await?
} else {
source
.iter_dir(&path, None, None)
.await?
.try_collect::<Vec<_>>()
.await?
};

Ok(files)
})
});
let lsr = lsr?;
let mut to_rtn = vec![];
for file in lsr.files {
for file in lsr {
let dict = PyDict::new(py);
dict.set_item("type", format!("{:?}", file.filetype))?;
dict.set_item("path", file.filepath)?;
Expand Down

0 comments on commit 4f9be30

Please sign in to comment.