diff --git a/Cargo.lock b/Cargo.lock index 5ec5a3e035..ed2640e265 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1095,6 +1095,7 @@ dependencies = [ "snafu", "tempfile", "tokio", + "tokio-stream", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 3a9f5a9ba3..19ea900845 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ rand = "^0.8" serde_json = "1.0.104" snafu = "0.7.4" tokio = {version = "1.32.0", features = ["net", "time", "bytes", "process", "signal", "macros", "rt", "rt-multi-thread"]} +tokio-stream = {version = "0.1.14", features = ["fs"]} [workspace.dependencies.arrow2] git = "https://github.com/Eventual-Inc/arrow2" diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index 604b139e48..c61a491685 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -20,6 +20,7 @@ pub enum DaftError { path: String, source: GenericError, }, + InternalError(String), External(GenericError), } @@ -31,7 +32,8 @@ impl std::error::Error for DaftError { | DaftError::TypeError(_) | DaftError::ComputeError(_) | DaftError::ArrowError(_) - | DaftError::ValueError(_) => None, + | DaftError::ValueError(_) + | DaftError::InternalError(_) => None, DaftError::IoError(io_error) => Some(io_error), DaftError::FileNotFound { source, .. } | DaftError::External(source) => Some(&**source), #[cfg(feature = "python")] @@ -96,6 +98,7 @@ impl Display for DaftError { Self::ComputeError(s) => write!(f, "DaftError::ComputeError {s}"), Self::ArrowError(s) => write!(f, "DaftError::ArrowError {s}"), Self::ValueError(s) => write!(f, "DaftError::ValueError {s}"), + Self::InternalError(s) => write!(f, "DaftError::InternalError {s}"), #[cfg(feature = "python")] Self::PyO3Error(e) => write!(f, "DaftError::PyO3Error {e}"), Self::IoError(e) => write!(f, "DaftError::IoError {e}"), diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index e36ae93a19..4662f3e33d 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -29,6 +29,7 @@ serde = {workspace = true} serde_json = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} +tokio-stream = {workspace = true} url = "2.4.0" [dependencies.reqwest] diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index ccbbedc622..6cf62e9634 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -2,12 +2,16 @@ use std::io::SeekFrom; use std::ops::Range; use std::path::PathBuf; -use crate::object_io::LSResult; +use crate::object_io::{self, FileMetadata, LSResult}; use super::object_io::{GetResult, ObjectSource}; use super::Result; use async_trait::async_trait; use bytes::Bytes; +use common_error::DaftError; +use futures::stream::BoxStream; +use futures::StreamExt; +use futures::TryStreamExt; use snafu::{ResultExt, Snafu}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncSeekExt}; @@ -33,6 +37,21 @@ enum Error { source: std::io::Error, }, + #[snafu(display("Unable to fetch file metadata for file {}: {}", path, source))] + UnableToFetchFileMetadata { + path: String, + source: std::io::Error, + }, + + #[snafu(display("Unable to get entries for directory {}: {}", path, source))] + UnableToFetchDirectoryEntries { + path: String, + source: std::io::Error, + }, + + #[snafu(display("Unexpected symlink when processing directory {}: {}", path, source))] + UnexpectedSymlink { path: String, source: DaftError }, + #[snafu(display("Unable to parse URL \"{}\"", url.to_string_lossy()))] InvalidUrl { url: PathBuf, source: ParseError }, @@ -44,7 +63,9 @@ impl From for super::Error { fn from(error: Error) -> Self { use Error::*; match error { - UnableToOpenFile { path, source } => { + UnableToOpenFile { path, source } + | UnableToFetchFileMetadata { path, source } + | UnableToFetchDirectoryEntries { path, source } => { use std::io::ErrorKind::*; match source.kind() { NotFound => super::Error::NotFound { @@ -84,49 +105,104 @@ pub struct LocalFile { #[async_trait] impl ObjectSource for LocalSource { async fn get(&self, uri: &str, range: Option>) -> super::Result { - const TO_STRIP: &str = "file://"; - if let Some(p) = uri.strip_prefix(TO_STRIP) { - let path = std::path::Path::new(p); + const LOCAL_PROTOCOL: &str = "file://"; + if let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) { Ok(GetResult::File(LocalFile { - path: path.to_path_buf(), + path: uri.into(), range, })) } else { - return Err(Error::InvalidFilePath { - path: uri.to_string(), - } - .into()); + Err(Error::InvalidFilePath { path: uri.into() }.into()) } } async fn get_size(&self, uri: &str) -> super::Result { - const TO_STRIP: &str = "file://"; - if let Some(p) = uri.strip_prefix(TO_STRIP) { - let path = std::path::Path::new(p); - let file = tokio::fs::File::open(path) - .await - .context(UnableToOpenFileSnafu { - path: path.to_string_lossy(), - })?; - let metadata = file.metadata().await.context(UnableToOpenFileSnafu { - path: path.to_string_lossy(), - })?; - return Ok(metadata.len() as usize); - } else { - return Err(Error::InvalidFilePath { + const LOCAL_PROTOCOL: &str = "file://"; + let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + return Err(Error::InvalidFilePath { path: uri.into() }.into()); + }; + let meta = tokio::fs::metadata(uri) + .await + .context(UnableToFetchFileMetadataSnafu { path: uri.to_string(), - } - .into()); - } + })?; + Ok(meta.len() as usize) } async fn ls( &self, - _path: &str, + path: &str, _delimiter: Option<&str>, _continuation_token: Option<&str>, ) -> super::Result { - unimplemented!("local ls"); + let s = self.iter_dir(path, None, None).await?; + let files = s.try_collect::>().await?; + Ok(LSResult { + files, + continuation_token: None, + }) + } + + async fn iter_dir( + &self, + uri: &str, + _delimiter: Option<&str>, + _limit: Option, + ) -> super::Result>> { + const LOCAL_PROTOCOL: &str = "file://"; + let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + return Err(Error::InvalidFilePath { path: uri.into() }.into()); + }; + let meta = + tokio::fs::metadata(uri) + .await + .with_context(|_| UnableToFetchFileMetadataSnafu { + path: uri.to_string(), + })?; + if meta.file_type().is_file() { + // Provided uri points to a file, so only return that file. + return Ok(futures::stream::iter([Ok(FileMetadata { + filepath: format!("{}{}", LOCAL_PROTOCOL, uri), + size: Some(meta.len()), + filetype: object_io::FileType::File, + })]) + .boxed()); + } + let dir_entries = tokio::fs::read_dir(uri).await.with_context(|_| { + UnableToFetchDirectoryEntriesSnafu { + path: uri.to_string(), + } + })?; + let dir_stream = tokio_stream::wrappers::ReadDirStream::new(dir_entries); + let uri = Arc::new(uri.to_string()); + let file_meta_stream = dir_stream.then(move |entry| { + let uri = uri.clone(); + async move { + let entry = entry.with_context(|_| UnableToFetchDirectoryEntriesSnafu { + path: uri.to_string(), + })?; + let meta = tokio::fs::metadata(entry.path()).await.with_context(|_| { + UnableToFetchFileMetadataSnafu { + path: entry.path().to_string_lossy().to_string(), + } + })?; + Ok(FileMetadata { + filepath: format!( + "{}{}{}", + LOCAL_PROTOCOL, + entry.path().to_string_lossy(), + if meta.is_dir() { "/" } else { "" } + ), + size: Some(meta.len()), + filetype: meta.file_type().try_into().with_context(|_| { + UnexpectedSymlinkSnafu { + path: entry.path().to_string_lossy().to_string(), + } + })?, + }) + } + }); + Ok(file_meta_stream.boxed()) } } @@ -171,16 +247,15 @@ pub(crate) async fn collect_file(local_file: LocalFile) -> Result { #[cfg(test)] mod tests { - use std::io::Write; - use crate::object_io::ObjectSource; + use crate::object_io::{FileMetadata, FileType, ObjectSource}; use crate::Result; use crate::{HttpSource, LocalSource}; - #[tokio::test] - async fn test_full_get_from_local() -> Result<()> { - let mut file1 = tempfile::NamedTempFile::new().unwrap(); + async fn write_remote_parquet_to_local_file( + f: &mut tempfile::NamedTempFile, + ) -> Result { let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; @@ -190,15 +265,22 @@ mod tests { let all_bytes = bytes.as_ref(); let checksum = format!("{:x}", md5::compute(all_bytes)); assert_eq!(checksum, parquet_expected_md5); - file1.write_all(all_bytes).unwrap(); - file1.flush().unwrap(); + f.write_all(all_bytes).unwrap(); + f.flush().unwrap(); + Ok(bytes) + } + + #[tokio::test] + async fn test_local_full_get() -> Result<()> { + let mut file1 = tempfile::NamedTempFile::new().unwrap(); + let bytes = write_remote_parquet_to_local_file(&mut file1).await?; let parquet_file_path = format!("file://{}", file1.path().to_str().unwrap()); let client = LocalSource::get_client().await?; let try_all_bytes = client.get(&parquet_file_path, None).await?.bytes().await?; - assert_eq!(try_all_bytes.len(), all_bytes.len()); - assert_eq!(try_all_bytes.as_ref(), all_bytes); + assert_eq!(try_all_bytes.len(), bytes.len()); + assert_eq!(try_all_bytes, bytes); let first_bytes = client .get_range(&parquet_file_path, 0..10) @@ -206,7 +288,7 @@ mod tests { .bytes() .await?; assert_eq!(first_bytes.len(), 10); - assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); + assert_eq!(first_bytes.as_ref(), &bytes[..10]); let first_bytes = client .get_range(&parquet_file_path, 10..100) @@ -214,21 +296,58 @@ mod tests { .bytes() .await?; assert_eq!(first_bytes.len(), 90); - assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]); + assert_eq!(first_bytes.as_ref(), &bytes[10..100]); let last_bytes = client - .get_range( - &parquet_file_path, - (all_bytes.len() - 10)..(all_bytes.len() + 10), - ) + .get_range(&parquet_file_path, (bytes.len() - 10)..(bytes.len() + 10)) .await? .bytes() .await?; assert_eq!(last_bytes.len(), 10); - assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); + assert_eq!(last_bytes.as_ref(), &bytes[(bytes.len() - 10)..]); let size_from_get_size = client.get_size(parquet_file_path.as_str()).await?; - assert_eq!(size_from_get_size, all_bytes.len()); + assert_eq!(size_from_get_size, bytes.len()); + + Ok(()) + } + + #[tokio::test] + async fn test_local_full_ls() -> Result<()> { + let dir = tempfile::tempdir().unwrap(); + let mut file1 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file1).await?; + let mut file2 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file2).await?; + let mut file3 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file3).await?; + let dir_path = format!("file://{}", dir.path().to_string_lossy()); + let client = LocalSource::get_client().await?; + + let ls_result = client.ls(dir_path.as_ref(), None, None).await?; + let mut files = ls_result.files.clone(); + // Ensure stable sort ordering of file paths before comparing with expected payload. + files.sort_by(|a, b| a.filepath.cmp(&b.filepath)); + let mut expected = vec![ + FileMetadata { + filepath: format!("file://{}", file1.path().to_string_lossy()), + size: Some(file1.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + FileMetadata { + filepath: format!("file://{}", file2.path().to_string_lossy()), + size: Some(file2.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + FileMetadata { + filepath: format!("file://{}", file3.path().to_string_lossy()), + size: Some(file3.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + ]; + expected.sort_by(|a, b| a.filepath.cmp(&b.filepath)); + assert_eq!(files, expected); + assert_eq!(ls_result.continuation_token, None); Ok(()) } diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 98bc23de31..9613d387d1 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; +use common_error::DaftError; use futures::stream::{BoxStream, Stream}; use futures::StreamExt; use tokio::sync::mpsc::Sender; @@ -52,12 +53,32 @@ impl GetResult { } } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub enum FileType { File, Directory, } -#[derive(Debug)] + +impl TryFrom for FileType { + type Error = DaftError; + + fn try_from(value: std::fs::FileType) -> Result { + if value.is_dir() { + Ok(Self::Directory) + } else if value.is_file() { + Ok(Self::File) + } else if value.is_symlink() { + Err(DaftError::InternalError(format!("Symlinks should never be encountered when constructing FileMetadata, but got: {:?}", value))) + } else { + unreachable!( + "Can only be a directory, file, or symlink, but got: {:?}", + value + ) + } + } +} + +#[derive(Debug, Clone, PartialEq)] pub struct FileMetadata { pub filepath: String, pub size: Option, diff --git a/tests/integration/io/test_list_files_local.py b/tests/integration/io/test_list_files_local.py new file mode 100644 index 0000000000..dfd016038b --- /dev/null +++ b/tests/integration/io/test_list_files_local.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import pytest +from fsspec.implementations.local import LocalFileSystem + +from daft.daft import io_list + + +def local_recursive_list(fs, path) -> list: + all_results = [] + curr_level_result = fs.ls(path, detail=True) + for item in curr_level_result: + if item["type"] == "directory": + new_path = item["name"] + all_results.extend(local_recursive_list(fs, new_path)) + item["name"] += "/" + all_results.append(item) + else: + all_results.append(item) + return all_results + + +def compare_local_result(daft_ls_result: list, fs_result: list): + daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] + fs_files = [(f'file://{f["name"]}', f["type"]) for f in fs_result] + assert sorted(daft_files) == sorted(fs_files) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_flat_directory_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b", "c"] + for name in files: + p = d / name + p.touch() + d = str(d) + if include_protocol: + d = "file://" + d + daft_ls_result = io_list(d) + fs = LocalFileSystem() + fs_result = fs.ls(d, detail=True) + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_recursive_directory_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + d = str(d) + if include_protocol: + d = "file://" + d + daft_ls_result = io_list(d, recursive=True) + fs = LocalFileSystem() + fs_result = local_recursive_list(fs, d) + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +@pytest.mark.parametrize( + "recursive", + [False, True], +) +def test_single_file_directory_listing(tmp_path, include_protocol, recursive): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + p = f"{d}/c/cc/ccc" + if include_protocol: + p = "file://" + p + daft_ls_result = io_list(p, recursive=recursive) + fs_result = [{"name": f"{d}/c/cc/ccc", "type": "file"}] + assert len(daft_ls_result) == 1 + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_missing_file_path(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + p = f"{d}/c/cc/ddd" + if include_protocol: + p = "file://" + p + with pytest.raises(FileNotFoundError, match=f"File: {d}/c/cc/ddd not found"): + daft_ls_result = io_list(p, recursive=True)