[FEAT] Native Recursive File Lister (#1353)
* Implements `iter_dir`, which yields a stream of directory entries from a single directory
* Also implements a recursive version that explores subdirectories (see the usage sketch below)
samster25 authored Sep 20, 2023
1 parent 1efeb33 commit f6370e5
Showing 10 changed files with 195 additions and 68 deletions.
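Before the per-file diffs, a minimal sketch of how the two new listers compose from inside the daft-io crate, assuming the crate-private items from src/daft-io/src/object_io.rs below (`ObjectSource`, `FileMetadata`, `FileType`, `recursive_iter`); the S3 URL is a placeholder, and a real caller would obtain `source` from `io_client.get_source(&scheme)`:

use std::sync::Arc;

use futures::StreamExt;

use crate::object_io::{recursive_iter, ObjectSource};

// Stream every file under a prefix, printing entries as they arrive
// rather than waiting for the full walk to finish.
async fn print_tree(source: Arc<dyn ObjectSource>) -> crate::Result<()> {
    // `iter_dir` alone would list a single level; `recursive_iter`
    // also descends into everything with FileType::Directory.
    let mut entries = recursive_iter(source, "s3://my-bucket/some-prefix/").await?;
    while let Some(entry) = entries.next().await {
        let entry = entry?;
        println!("{:?}\t{}", entry.filetype, entry.filepath);
    }
    Ok(())
}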
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion src/daft-core/Cargo.toml
@@ -8,7 +8,7 @@ common-error = {path = "../common/error", default-features = false}
dyn-clone = "1.0.12"
fnv = "1.0.7"
html-escape = {workspace = true}
-indexmap = {workspace = true, features = ["serde"], version = "2.0.0"}
+indexmap = {workspace = true, features = ["serde"]}
lazy_static = {workspace = true}
log = {workspace = true}
ndarray = "0.15.6"
6 changes: 3 additions & 3 deletions src/daft-core/src/array/ops/concat_agg.rs
@@ -165,7 +165,7 @@ mod test {
Field::new("foo", DataType::List(Box::new(DataType::Int64))),
Int64Array::from((
"item",
Box::new(arrow2::array::Int64Array::from_iter(vec![].iter())),
Box::new(arrow2::array::Int64Array::from_iter([].iter())),
))
.into_series(),
arrow2::offset::OffsetsBuffer::<i64>::try_from(vec![0, 0, 0, 0])?,
@@ -190,7 +190,7 @@
            Int64Array::from((
                "item",
                Box::new(arrow2::array::Int64Array::from_iter(
-                    vec![Some(0), Some(1), Some(1), Some(2), None, None, Some(10000)].iter(),
+                    [Some(0), Some(1), Some(1), Some(2), None, None, Some(10000)].iter(),
                )),
            ))
            .into_series(),
@@ -225,7 +225,7 @@ Int64Array::from((
            Int64Array::from((
                "item",
                Box::new(arrow2::array::Int64Array::from_iter(
-                    vec![
+                    [
                        Some(0),
                        Some(0),
                        Some(0),
1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml
@@ -1,5 +1,6 @@
[dependencies]
async-recursion = "1.0.4"
+async-stream = "0.3.5"
async-trait = "0.1.71"
aws-config = {version = "0.55.3", features = ["native-tls", "rt-tokio", "client-hyper", "credentials-sso"], default-features = false}
aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]}
60 changes: 59 additions & 1 deletion src/daft-io/src/object_io.rs
@@ -1,9 +1,11 @@
use std::ops::Range;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
-use futures::stream::BoxStream;
+use futures::stream::{BoxStream, Stream};
+use futures::StreamExt;
+use tokio::sync::mpsc::Sender;
use tokio::sync::OwnedSemaphorePermit;

use crate::local::{collect_file, LocalFile};
@@ -67,6 +69,8 @@ pub struct LSResult {
    pub continuation_token: Option<String>,
}

+use async_stream::stream;
+
#[async_trait]
pub(crate) trait ObjectSource: Sync + Send {
    async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult>;
@@ -81,5 +85,59 @@ pub(crate) trait ObjectSource: Sync + Send {
        continuation_token: Option<&str>,
    ) -> super::Result<LSResult>;

+    // async fn iter_dir(&self, path: &str, limit: Option<usize>) -> super::Result<Box<dyn Stream<Item = super::Result<LSResult>>>>;
+    async fn iter_dir(
+        &self,
+        uri: &str,
+        delimiter: Option<&str>,
+        _limit: Option<usize>,
+    ) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
+        let uri = uri.to_string();
+        let delimiter = delimiter.map(String::from);
+        let s = stream! {
+            let lsr = self.ls(&uri, delimiter.as_deref(), None).await?;
+            let mut continuation_token = lsr.continuation_token.clone();
+            for file in lsr.files {
+                yield Ok(file);
+            }
+
+            while continuation_token.is_some() {
+                let lsr = self.ls(&uri, delimiter.as_deref(), continuation_token.as_deref()).await?;
+                continuation_token = lsr.continuation_token.clone();
+                for file in lsr.files {
+                    yield Ok(file);
+                }
+            }
+        };
+        Ok(s.boxed())
+    }
}

+pub(crate) async fn recursive_iter(
+    source: Arc<dyn ObjectSource>,
+    uri: &str,
+) -> super::Result<BoxStream<super::Result<FileMetadata>>> {
+    let (to_rtn_tx, mut to_rtn_rx) = tokio::sync::mpsc::channel(16 * 1024);
+    fn add_to_channel(source: Arc<dyn ObjectSource>, tx: Sender<FileMetadata>, dir: String) {
+        tokio::spawn(async move {
+            let mut s = source.iter_dir(&dir, None, None).await.unwrap();
+            let tx = &tx;
+            while let Some(tr) = s.next().await {
+                let tr = tr.unwrap();
+                match tr.filetype {
+                    FileType::File => tx.send(tr).await.unwrap(),
+                    FileType::Directory => add_to_channel(source.clone(), tx.clone(), tr.filepath),
+                };
+            }
+        });
+    }
+
+    add_to_channel(source, to_rtn_tx, uri.to_string());
+
+    let to_rtn_stream = stream! {
+        while let Some(v) = to_rtn_rx.recv().await {
+            yield Ok(v)
+        }
+    };
+
+    Ok(to_rtn_stream.boxed())
+}
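The default `iter_dir` pages through `ls` until the backend stops returning a continuation token, and `recursive_iter` fans the walk out: every directory gets its own Tokio task, files are funneled into one bounded mpsc channel (the `.unwrap()`s abort the walk on any error), and the stream terminates when the last task drops its `Sender`. A self-contained toy of that fan-out pattern, with an in-memory tree standing in for object storage; every name here is illustrative, nothing is a daft-io API (assumes tokio with the "full" feature):

use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::mpsc::{channel, Sender};

// dir -> (entry name, is_dir)
type Tree = HashMap<&'static str, Vec<(&'static str, bool)>>;

fn visit(tree: Arc<Tree>, tx: Sender<String>, dir: &'static str) {
    // One task per directory, exactly like add_to_channel above.
    tokio::spawn(async move {
        for (name, is_dir) in tree.get(dir).cloned().unwrap_or_default() {
            if is_dir {
                visit(tree.clone(), tx.clone(), name); // recurse by spawning
            } else {
                tx.send(format!("{dir}/{name}")).await.unwrap();
            }
        }
    });
}

#[tokio::main]
async fn main() {
    let tree: Tree = HashMap::from([
        ("root", vec![("a.txt", false), ("sub", true)]),
        ("sub", vec![("b.txt", false)]),
    ]);
    let (tx, mut rx) = channel(64);
    visit(Arc::new(tree), tx, "root");
    // recv() yields None once every task has finished and dropped its
    // Sender, which is also what ends the stream in recursive_iter.
    while let Some(path) = rx.recv().await {
        println!("{path}");
    }
}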
23 changes: 19 additions & 4 deletions src/daft-io/src/python.rs
@@ -2,8 +2,9 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig};
pub use py::register_modules;

mod py {
-    use crate::{get_io_client, get_runtime, object_io::LSResult, parse_url};
+    use crate::{get_io_client, get_runtime, object_io::recursive_iter, parse_url};
    use common_error::DaftResult;
+    use futures::TryStreamExt;
    use pyo3::{
        prelude::*,
        types::{PyDict, PyList},
@@ -13,10 +14,11 @@ mod py {
    fn io_list(
        py: Python,
        path: String,
+        recursive: Option<bool>,
        multithreaded_io: Option<bool>,
        io_config: Option<common_io_config::python::IOConfig>,
    ) -> PyResult<&PyList> {
-        let lsr: DaftResult<LSResult> = py.allow_threads(|| {
+        let lsr: DaftResult<Vec<_>> = py.allow_threads(|| {
            let io_client = get_io_client(
                multithreaded_io.unwrap_or(true),
                io_config.unwrap_or_default().config.into(),
@@ -27,12 +29,25 @@

            runtime_handle.block_on(async move {
                let source = io_client.get_source(&scheme).await?;
-                Ok(source.ls(&path, None, None).await?)
+                let files = if recursive.is_some_and(|r| r) {
+                    recursive_iter(source, &path)
+                        .await?
+                        .try_collect::<Vec<_>>()
+                        .await?
+                } else {
+                    source
+                        .iter_dir(&path, None, None)
+                        .await?
+                        .try_collect::<Vec<_>>()
+                        .await?
+                };
+
+                Ok(files)
            })
        });
        let lsr = lsr?;
        let mut to_rtn = vec![];
-        for file in lsr.files {
+        for file in lsr {
            let dict = PyDict::new(py);
            dict.set_item("type", format!("{:?}", file.filetype))?;
            dict.set_item("path", file.filepath)?;
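The binding drains either stream eagerly with `TryStreamExt::try_collect`, which collapses a stream of `Result` items into a single `Result<Vec<_>>`, short-circuiting on the first error. A minimal standalone illustration of that pattern (assumes only the futures and tokio crates, nothing daft-specific):

use futures::stream::{self, TryStreamExt};

#[tokio::main]
async fn main() {
    let s = stream::iter(vec![Ok::<i32, String>(1), Ok(2), Ok(3)]);
    let collected: Result<Vec<i32>, String> = s.try_collect().await;
    assert_eq!(collected, Ok(vec![1, 2, 3]));

    // An Err item aborts the collection and surfaces immediately.
    let s = stream::iter(vec![Ok::<i32, String>(1), Err("boom".to_string())]);
    assert!(s.try_collect::<Vec<i32>>().await.is_err());
}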
104 changes: 60 additions & 44 deletions src/daft-io/src/s3_like.rs
@@ -134,6 +134,20 @@ impl From<Error> for super::Error {
                    source: err.into(),
                },
            },
+            UnableToListObjects { path, source } => match source.into_service_error() {
+                ListObjectsV2Error::NoSuchBucket(no_such_key) => super::Error::NotFound {
+                    path,
+                    source: no_such_key.into(),
+                },
+                ListObjectsV2Error::Unhandled(v) => super::Error::Unhandled {
+                    path,
+                    msg: DisplayErrorContext(v).to_string(),
+                },
+                err => super::Error::UnableToOpenFile {
+                    path,
+                    source: err.into(),
+                },
+            },
            InvalidUrl { path, source } => super::Error::InvalidUrl { path, source },
            UnableToReadBytes { path, source } => super::Error::UnableToReadBytes {
                path,
@@ -538,7 +552,7 @@ impl S3LikeSource {
                Ok(v) => {
                    let dirs = v.common_prefixes();
                    let files = v.contents();
-                    let continuation_token = v.continuation_token().map(|s| s.to_string());
+                    let continuation_token = v.next_continuation_token().map(|s| s.to_string());
                    let mut total_len = 0;
                    if let Some(dirs) = dirs {
                        total_len += dirs.len()
@@ -641,6 +655,8 @@ impl ObjectSource for S3LikeSource {
    ) -> super::Result<LSResult> {
        let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?;
        let delimiter = delimiter.unwrap_or("/");
+        log::warn!("{:?}", parsed);
+
        let bucket = match parsed.host_str() {
            Some(s) => Ok(s),
            None => Err(Error::InvalidUrl {
@@ -650,57 +666,57 @@
        }?;
        let key = parsed.path();

-        if let Some(key) = key.strip_prefix('/') {
-            // assume its a directory first
-            let key = format!("{}/", key.strip_suffix('/').unwrap_or(key));
-            let lsr = {
-                let permit = self
-                    .connection_pool_sema
-                    .acquire()
-                    .await
-                    .context(UnableToGrabSemaphoreSnafu)?;
-                self._list_impl(
-                    permit,
-                    bucket,
-                    &key,
-                    delimiter.into(),
-                    continuation_token.map(String::from),
-                    &self.default_region,
-                )
-                .await?
-            };
-            if lsr.files.is_empty() && key.contains('/') {
-                let permit = self
-                    .connection_pool_sema
-                    .acquire()
-                    .await
-                    .context(UnableToGrabSemaphoreSnafu)?;
-                // Might be a File
-                let split = key.rsplit_once('/');
-                let (new_key, _) = split.unwrap();
-                let mut lsr = self
-                    ._list_impl(
-                        permit,
-                        bucket,
-                        new_key,
-                        delimiter.into(),
-                        continuation_token.map(String::from),
-                        &self.default_region,
-                    )
-                    .await?;
-                let target_path = format!("s3://{bucket}/{new_key}");
-                lsr.files.retain(|f| f.filepath == target_path);
-
-                if lsr.files.is_empty() {
-                    // Isnt a file or a directory
-                    return Err(Error::NotFound { path: path.into() }.into());
-                }
-                Ok(lsr)
-            } else {
-                Ok(lsr)
-            }
-        } else {
-            Err(Error::NotAFile { path: path.into() }.into())
-        }
+        let key = key
+            .strip_prefix('/')
+            .map(|k| k.strip_suffix('/').unwrap_or(k));
+        let key = key.unwrap_or("");
+
+        // assume its a directory first
+        let lsr = {
+            let permit = self
+                .connection_pool_sema
+                .acquire()
+                .await
+                .context(UnableToGrabSemaphoreSnafu)?;
+            self._list_impl(
+                permit,
+                bucket,
+                key,
+                delimiter.into(),
+                continuation_token.map(String::from),
+                &self.default_region,
+            )
+            .await?
+        };
+        if lsr.files.is_empty() && key.contains('/') {
+            let permit = self
+                .connection_pool_sema
+                .acquire()
+                .await
+                .context(UnableToGrabSemaphoreSnafu)?;
+            // Might be a File
+            let split = key.rsplit_once('/');
+            let (new_key, _) = split.unwrap();
+            let mut lsr = self
+                ._list_impl(
+                    permit,
+                    bucket,
+                    new_key,
+                    delimiter.into(),
+                    continuation_token.map(String::from),
+                    &self.default_region,
+                )
+                .await?;
+            let target_path = format!("s3://{bucket}/{new_key}");
+            lsr.files.retain(|f| f.filepath == target_path);
+
+            if lsr.files.is_empty() {
+                // Isnt a file or a directory
+                return Err(Error::NotFound { path: path.into() }.into());
+            }
+            Ok(lsr)
+        } else {
+            Ok(lsr)
+        }
    }
}
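The one-line `next_continuation_token` fix matters because in a ListObjectsV2 response, `continuation_token` merely echoes the token the request carried, while `next_continuation_token` is the cursor for the following page; paging on the former re-reads the same page forever. A toy paging loop with mock types (nothing here is the AWS SDK) showing the correct pattern:

struct Page {
    items: Vec<u32>,
    next_token: Option<u32>,
}

fn list(token: Option<u32>) -> Page {
    // Pretend each "page" holds two items, across pages 0..3.
    let start = token.unwrap_or(0);
    Page {
        items: vec![start * 2, start * 2 + 1],
        next_token: if start < 2 { Some(start + 1) } else { None },
    }
}

fn main() {
    let mut all = vec![];
    let mut page = list(None);
    all.extend(&page.items);
    while let Some(t) = page.next_token {
        page = list(Some(t)); // pass the *next* token forward
        all.extend(&page.items);
    }
    assert_eq!(all, vec![0, 1, 2, 3, 4, 5]);
}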
4 changes: 2 additions & 2 deletions src/daft-plan/src/test/mod.rs
@@ -12,7 +12,7 @@ use crate::{
pub fn dummy_scan_node(fields: Vec<Field>) -> LogicalPlanBuilder {
    let schema = Arc::new(Schema::new(fields).unwrap());
    LogicalPlanBuilder::table_scan(
-        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]).into(),
+        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]),
        schema,
        FileFormatConfig::Json(JsonSourceConfig {}).into(),
        StorageConfig::Native(NativeStorageConfig::new_internal(None).into()).into(),
@@ -24,7 +24,7 @@
pub fn dummy_scan_node_with_limit(fields: Vec<Field>, limit: Option<usize>) -> LogicalPlanBuilder {
    let schema = Arc::new(Schema::new(fields).unwrap());
    LogicalPlanBuilder::table_scan_with_limit(
-        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]).into(),
+        FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]),
        schema,
        FileFormatConfig::Json(JsonSourceConfig {}).into(),
        StorageConfig::Native(NativeStorageConfig::new_internal(None).into()).into(),
37 changes: 24 additions & 13 deletions tests/integration/io/conftest.py
@@ -79,8 +79,8 @@ def retry_server_s3_config(request) -> daft.io.IOConfig:


@contextlib.contextmanager
-def mount_data_minio(
-    minio_io_config: daft.io.IOConfig, folder: pathlib.Path, bucket_name: str = "my-minio-bucket"
+def minio_create_bucket(
+    minio_io_config: daft.io.IOConfig, bucket_name: str = "my-minio-bucket"
) -> YieldFixture[list[str]]:
    """Mounts data in `folder` into files in minio
@@ -92,22 +92,33 @@ def mount_data_minio(
        client_kwargs={"endpoint_url": minio_io_config.s3.endpoint_url},
    )
    fs.mkdir(bucket_name)

-    urls = []
-    for p in folder.glob("**/*"):
-        if not p.is_file():
-            continue
-        key = str(p.relative_to(folder))
-        url = f"s3://{bucket_name}/{key}"
-        fs.write_bytes(url, p.read_bytes())
-        urls.append(url)
-
    try:
-        yield urls
+        yield fs
    finally:
        fs.rm(bucket_name, recursive=True)


+@contextlib.contextmanager
+def mount_data_minio(
+    minio_io_config: daft.io.IOConfig, folder: pathlib.Path, bucket_name: str = "my-minio-bucket"
+) -> YieldFixture[list[str]]:
+    """Mounts data in `folder` into files in minio
+    Yields a list of S3 URLs
+    """
+    with minio_create_bucket(minio_io_config=minio_io_config, bucket_name=bucket_name) as fs:
+        urls = []
+        for p in folder.glob("**/*"):
+            if not p.is_file():
+                continue
+            key = str(p.relative_to(folder))
+            url = f"s3://{bucket_name}/{key}"
+            fs.write_bytes(url, p.read_bytes())
+            urls.append(url)
+
+        yield urls


@contextlib.contextmanager
def mount_data_nginx(nginx_config: tuple[str, pathlib.Path], folder: pathlib.Path) -> YieldFixture[list[str]]:
    """Mounts data in `folder` into servable static files in NGINX