From f45f5231cd6d8b8a362d65cd4b74281fc16f3a32 Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Mon, 25 Sep 2023 15:00:10 -0700 Subject: [PATCH 1/7] [BUG] Fix runner check at plan execution time for new query planner (#1435) This PR fixes `is_ray_runner` check at plan execution time in the new query planner. Previously, we were checking the Daft context at execution time, which wouldn't be properly set under the Ray runner, since we don't propagate the Daft context to all Ray workers. This PR changes the runner check to be based on an explicit flag passed from the runner to the physical plan execution API. Closes #1433 --- daft/daft.pyi | 4 +- daft/execution/rust_physical_plan_shim.py | 4 +- daft/planner/planner.py | 4 +- daft/planner/py_planner.py | 4 +- daft/planner/rust_planner.py | 6 ++- daft/runners/pyrunner.py | 2 +- daft/runners/ray_runner.py | 2 +- src/daft-plan/src/physical_plan.rs | 53 ++++++++++++++--------- 8 files changed, 49 insertions(+), 30 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index c4d9eee604..44d2c6880f 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -673,7 +673,9 @@ class PhysicalPlanScheduler: A work scheduler for physical query plans. """ - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: ... + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: ... class LogicalPlanBuilder: """ diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 2ffa9057a3..76afe65f8d 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -2,7 +2,6 @@ from typing import Iterator, TypeVar, cast -from daft.context import get_context from daft.daft import ( FileFormat, FileFormatConfig, @@ -29,10 +28,11 @@ def tabular_scan( file_format_config: FileFormatConfig, storage_config: StorageConfig, limit: int, + is_ray_runner: bool, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: # TODO(Clark): Fix this Ray runner hack. 
part = Table._from_pytable(file_info_table) - if get_context().is_ray_runner: + if is_ray_runner: import ray parts = [ray.put(part)] diff --git a/daft/planner/planner.py b/daft/planner/planner.py index 5ee5a66346..1120f88e83 100644 --- a/daft/planner/planner.py +++ b/daft/planner/planner.py @@ -12,5 +12,7 @@ class PhysicalPlanScheduler(ABC): """ @abstractmethod - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: pass diff --git a/daft/planner/py_planner.py b/daft/planner/py_planner.py index bf5321a7bb..a9ab2b90f8 100644 --- a/daft/planner/py_planner.py +++ b/daft/planner/py_planner.py @@ -9,5 +9,7 @@ class PyPhysicalPlanScheduler(PhysicalPlanScheduler): def __init__(self, plan: logical_plan.LogicalPlan): self._plan = plan - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: return physical_plan.materialize(physical_plan_factory._get_physical_plan(self._plan, psets)) diff --git a/daft/planner/rust_planner.py b/daft/planner/rust_planner.py index cd1c00fd83..9ea74edcd5 100644 --- a/daft/planner/rust_planner.py +++ b/daft/planner/rust_planner.py @@ -9,5 +9,7 @@ class RustPhysicalPlanScheduler(PhysicalPlanScheduler): def __init__(self, scheduler: _PhysicalPlanScheduler): self._scheduler = scheduler - def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: - return physical_plan.materialize(self._scheduler.to_partition_tasks(psets)) + def to_partition_tasks( + self, psets: dict[str, list[PartitionT]], is_ray_runner: bool + ) -> physical_plan.MaterializedPhysicalPlan: + return physical_plan.materialize(self._scheduler.to_partition_tasks(psets, is_ray_runner)) diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 1ad9dcd7e5..f291727b13 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -148,7 +148,7 @@ def run_iter(self, builder: LogicalPlanBuilder) -> Iterator[Table]: if entry.value is not None } # Get executable tasks from planner. - tasks = plan_scheduler.to_partition_tasks(psets) + tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False) with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"): partitions_gen = self._physical_plan_to_partitions(tasks) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d596974ff9..a7dd27f355 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -432,7 +432,7 @@ def _run_plan( from loguru import logger # Get executable tasks from plan scheduler. - tasks = plan_scheduler.to_partition_tasks(psets) + tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) # Note: For autoscaling clusters, we will probably want to query cores dynamically. # Keep in mind this call takes about 0.3ms. diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 9c87facaa9..9766c530ac 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -64,8 +64,12 @@ pub struct PhysicalPlanScheduler { #[pymethods] impl PhysicalPlanScheduler { /// Converts the contained physical plan into an iterator of executable partition tasks. 
- pub fn to_partition_tasks(&self, psets: HashMap>) -> PyResult { - Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets)) + pub fn to_partition_tasks( + &self, + psets: HashMap>, + is_ray_runner: bool, + ) -> PyResult { + Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets, is_ray_runner)) } } @@ -98,6 +102,7 @@ impl PartitionIterator { } #[cfg(feature = "python")] +#[allow(clippy::too_many_arguments)] fn tabular_scan( py: Python<'_>, source_schema: &SchemaRef, @@ -106,6 +111,7 @@ fn tabular_scan( file_format_config: &Arc, storage_config: &Arc, limit: &Option, + is_ray_runner: bool, ) -> PyResult { let columns_to_read = projection_schema .fields @@ -123,6 +129,7 @@ fn tabular_scan( PyFileFormatConfig::from(file_format_config.clone()), PyStorageConfig::from(storage_config.clone()), *limit, + is_ray_runner, ))?; Ok(py_iter.into()) } @@ -162,6 +169,7 @@ impl PhysicalPlan { &self, py: Python<'_>, psets: &HashMap>, + is_ray_runner: bool, ) -> PyResult { match self { PhysicalPlan::InMemoryScan(InMemoryScan { @@ -198,6 +206,7 @@ impl PhysicalPlan { file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::TabularScanCsv(TabularScanCsv { projection_schema, @@ -219,6 +228,7 @@ impl PhysicalPlan { file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::TabularScanJson(TabularScanJson { projection_schema, @@ -240,13 +250,14 @@ impl PhysicalPlan { file_format_config, storage_config, limit, + is_ray_runner, ), PhysicalPlan::Project(Project { input, projection, resource_request, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let projection_pyexprs: Vec = projection .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -258,7 +269,7 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Filter(Filter { input, predicate }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; let py_predicate = expressions_mod @@ -287,7 +298,7 @@ impl PhysicalPlan { limit, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_physical_plan = py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; let local_limit_iter = py_physical_plan @@ -299,7 +310,7 @@ impl PhysicalPlan { Ok(global_limit_iter.into()) } PhysicalPlan::Explode(Explode { input, to_explode }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let explode_pyexprs: Vec = to_explode .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -316,7 +327,7 @@ impl PhysicalPlan { descending, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let sort_by_pyexprs: Vec = sort_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -337,7 +348,7 @@ impl PhysicalPlan { input_num_partitions, output_num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "split"))? 
@@ -345,7 +356,7 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Flatten(Flatten { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "flatten_plan"))? @@ -356,7 +367,7 @@ impl PhysicalPlan { input, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "fanout_random"))? @@ -368,7 +379,7 @@ impl PhysicalPlan { num_partitions, partition_by, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let partition_by_pyexprs: Vec = partition_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -383,7 +394,7 @@ impl PhysicalPlan { "FanoutByRange not implemented, since only use case (sorting) doesn't need it yet." ), PhysicalPlan::ReduceMerge(ReduceMerge { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "reduce_merge"))? @@ -396,7 +407,7 @@ impl PhysicalPlan { input, .. }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let aggs_as_pyexprs: Vec = aggregations .iter() .map(|agg_expr| PyExpr::from(Expr::Agg(agg_expr.clone()))) @@ -416,7 +427,7 @@ impl PhysicalPlan { num_from, num_to, }) => { - let upstream_iter = input.to_partition_tasks(py, psets)?; + let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "coalesce"))? @@ -424,8 +435,8 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Concat(Concat { other, input }) => { - let upstream_input_iter = input.to_partition_tasks(py, psets)?; - let upstream_other_iter = other.to_partition_tasks(py, psets)?; + let upstream_input_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_other_iter = other.to_partition_tasks(py, psets, is_ray_runner)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "concat"))? @@ -440,8 +451,8 @@ impl PhysicalPlan { join_type, .. 
}) => { - let upstream_left_iter = left.to_partition_tasks(py, psets)?; - let upstream_right_iter = right.to_partition_tasks(py, psets)?; + let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; let left_on_pyexprs: Vec = left_on .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -474,7 +485,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, @@ -493,7 +504,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, @@ -512,7 +523,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets)?, + input.to_partition_tasks(py, psets, is_ray_runner)?, file_format, schema, root_dir, From a024123dd2f157d562ac73b9b337b40914b1921e Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 25 Sep 2023 15:55:13 -0700 Subject: [PATCH 2/7] [BUG] Fix scheme bug in GCS anonymous mode (#1443) I think the "deadlocking" in #1432 is actually caused by Python's difflib being really slow on the diff of the bad result. Also, because our io integration tests on that workflow isn't provided with GCP credentials, it is defaulting to the "anonymous" GCS client which is s3 based. We should fix that once we verify that this fix works. Closes: #1432 --------- Co-authored-by: Jay Chia --- src/daft-io/src/s3_like.rs | 20 ++++++++++++++++---- tests/integration/io/test_list_files_gcs.py | 20 +++++++++++++------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 2b4614e110..8dba518e46 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -541,10 +541,12 @@ impl S3LikeSource { } } + #[allow(clippy::too_many_arguments)] #[async_recursion] async fn _list_impl( &self, _permit: SemaphorePermit<'async_recursion>, + scheme: &str, bucket: &str, key: &str, delimiter: String, @@ -587,7 +589,7 @@ impl S3LikeSource { } else { request.send().await }; - let uri = &format!("s3://{bucket}/{key}"); + let uri = &format!("{scheme}://{bucket}/{key}"); match response { Ok(v) => { let dirs = v.common_prefixes(); @@ -604,7 +606,10 @@ impl S3LikeSource { if let Some(dirs) = dirs { for d in dirs { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", d.prefix().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + d.prefix().unwrap_or_default() + ), size: None, filetype: FileType::Directory, }; @@ -614,7 +619,10 @@ impl S3LikeSource { if let Some(files) = files { for f in files { let fmeta = FileMetadata { - filepath: format!("s3://{bucket}/{}", f.key().unwrap_or_default()), + filepath: format!( + "{scheme}://{bucket}/{}", + f.key().unwrap_or_default() + ), size: Some(f.size() as u64), filetype: FileType::File, }; @@ -646,6 +654,7 @@ impl S3LikeSource { log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region); self._list_impl( _permit, + scheme, bucket, key, delimiter, @@ -694,6 +703,7 @@ impl ObjectSource for S3LikeSource { continuation_token: Option<&str>, ) -> super::Result { let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let scheme = parsed.scheme(); let delimiter = 
delimiter.unwrap_or("/"); let bucket = match parsed.host_str() { @@ -723,6 +733,7 @@ impl ObjectSource for S3LikeSource { self._list_impl( permit, + scheme, bucket, &key, delimiter.into(), @@ -742,6 +753,7 @@ impl ObjectSource for S3LikeSource { let mut lsr = self ._list_impl( permit, + scheme, bucket, key, delimiter.into(), @@ -749,7 +761,7 @@ impl ObjectSource for S3LikeSource { &self.default_region, ) .await?; - let target_path = format!("s3://{bucket}/{key}"); + let target_path = format!("{scheme}://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); if lsr.files.is_empty() { diff --git a/tests/integration/io/test_list_files_gcs.py b/tests/integration/io/test_list_files_gcs.py index 640d9d5be4..1053e25fa1 100644 --- a/tests/integration/io/test_list_files_gcs.py +++ b/tests/integration/io/test_list_files_gcs.py @@ -3,9 +3,11 @@ import gcsfs import pytest -from daft.daft import io_list +from daft.daft import GCSConfig, IOConfig, io_list BUCKET = "daft-public-data-gs" +DEFAULT_GCS_CONFIG = GCSConfig(project_id=None, anonymous=None) +ANON_GCS_CONFIG = GCSConfig(project_id=None, anonymous=True) def gcsfs_recursive_list(fs, path) -> list: @@ -49,28 +51,32 @@ def compare_gcs_result(daft_ls_result: list, fsspec_result: list): ], ) @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_flat_directory_listing(path, recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_flat_directory_listing(path, recursive, gcs_config): fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() @pytest.mark.parametrize("recursive", [False, True]) -def test_gs_single_file_listing(recursive): +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_single_file_listing(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/file.txt" fs = gcsfs.GCSFileSystem() - daft_ls_result = io_list(path, recursive=recursive) + daft_ls_result = io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) fsspec_result = gcsfs_recursive_list(fs, path) if recursive else fs.ls(path, detail=True) compare_gcs_result(daft_ls_result, fsspec_result) @pytest.mark.integration() -def test_gs_notfound(): +@pytest.mark.parametrize("recursive", [False, True]) +@pytest.mark.parametrize("gcs_config", [DEFAULT_GCS_CONFIG, ANON_GCS_CONFIG]) +def test_gs_notfound(recursive, gcs_config): path = f"gs://{BUCKET}/test_ls/MISSING" fs = gcsfs.GCSFileSystem() with pytest.raises(FileNotFoundError): fs.ls(path, detail=True) with pytest.raises(FileNotFoundError, match=path): - io_list(path) + io_list(path, recursive=recursive, io_config=IOConfig(gcs=gcs_config)) From 97845d30dd336874c1666b23cac10f431850d14e Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Mon, 25 Sep 2023 16:03:46 -0700 Subject: [PATCH 3/7] [BUG] Fix num input partitions in coalesce. (#1442) This PR fixes the number of input partitions in our `Coalesce` implementation. The new query planner was incorrectly using the current logical node's `num_partitions` for the number of _input_ partitions, which for a `Coalesce` op, is equal to the _output_ number of partitions; we should instead be using the input (child) logical node's `num_partitions`. 
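As a concrete illustration (essentially the sanity test added below), coalescing from many partitions down to fewer should round-trip the data unchanged; under the old planner logic the `Coalesce` op was handed the output partition count where the input count was expected:

```python
import daft

# Repro of the Coalesce path: 20 input partitions coalesced into 1.
data = {"foo": list(range(100))}
df = daft.from_pydict(data).into_partitions(20).into_partitions(1).collect()
assert df.to_pydict() == data
```
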
The microbenchmarking failure also illustrated that we didn't have any unit testing coverage of `Coalesce` (i.e. `df.into_partitions(n)` where `n` is smaller than the current number of partitions), so we also add a sanity unit test. Closes #1434 --- src/daft-plan/src/planner.rs | 2 +- tests/dataframe/test_repartition.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 5a38d0ed02..31b2f5e0fb 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -171,7 +171,7 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { let input_physical = plan(input)?; Ok(PhysicalPlan::Coalesce(Coalesce::new( input_physical.into(), - logical_plan.partition_spec().num_partitions, + input.partition_spec().num_partitions, *num_to, ))) } diff --git a/tests/dataframe/test_repartition.py b/tests/dataframe/test_repartition.py index 92c96e5721..88d6281088 100644 --- a/tests/dataframe/test_repartition.py +++ b/tests/dataframe/test_repartition.py @@ -7,3 +7,9 @@ def test_into_partitions_some_empty() -> None: data = {"foo": [1, 2, 3]} df = daft.from_pydict(data).into_partitions(32).collect() assert df.to_pydict() == data + + +def test_into_partitions_coalesce() -> None: + data = {"foo": list(range(100))} + df = daft.from_pydict(data).into_partitions(20).into_partitions(1).collect() + assert df.to_pydict() == data From 3403c0c5a16f73303bb0d72fafa9a8eb3c9ec9a5 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Mon, 25 Sep 2023 18:13:43 -0700 Subject: [PATCH 4/7] [BUG] only publish on main for schedule dispatch anaconda (#1444) --- .github/workflows/python-publish.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 309b6ec26f..0b2dcf8af7 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -130,7 +130,6 @@ jobs: publish: name: Publish wheels to PYPI and Anaconda - if: ${{ (github.ref == 'refs/heads/main') }} runs-on: ubuntu-latest needs: - build-and-test @@ -166,7 +165,7 @@ jobs: run: conda install -q -y anaconda-client "urllib3<2.0" - name: Upload wheels to anaconda nightly - if: ${{ success() && (env.IS_SCHEDULE_DISPATCH == 'true' || env.IS_PUSH == 'true') }} + if: ${{ success() && (((env.IS_SCHEDULE_DISPATCH == 'true') && (github.ref == 'refs/heads/main')) || env.IS_PUSH == 'true') }} shell: bash -el {0} env: DAFT_STAGING_UPLOAD_TOKEN: ${{ secrets.DAFT_STAGING_UPLOAD_TOKEN }} From 558b31e6f213e5b4972076f08bf1131b9ad3c75e Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 25 Sep 2023 18:32:01 -0700 Subject: [PATCH 5/7] [FEAT] Native listing of http URLs (#1405) Adds capabilities for our HTTP source to perform file listing Logic is roughly adapted from fsspec's [HTTPFileSystem](https://filesystem-spec.readthedocs.io/en/latest/_modules/fsspec/implementations/http.html) in order to maintain some parity with how it behaves. Current differences in behavior: 1. We don't support fsspec's `simple_links=True` option, which would search for any text starting with `http(s)://` and that aren't encased in a HTML `` tag. 2. When pointed at a single file, fsspec returns an empty list but we detect whether the file is HTML and return a list with just the file itself if we see that it is not HTML. 
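A rough usage sketch of the listing behavior described above (not part of this patch); the server URL is a placeholder for any HTTP server that serves an HTML index, such as the nginx fixture used in the integration tests:

```python
from daft.daft import io_list

# Listing an HTML index page yields one entry per <a href=...> link that is a
# descendant of the requested path; "size" is None since it is not read from
# the index page.
for entry in io_list("http://localhost:8080/test_ls/", recursive=False):
    print(entry["path"], entry["type"], entry["size"])
```
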
--------- Co-authored-by: Jay Chia --- Cargo.lock | 17 +- src/daft-io/Cargo.toml | 1 + src/daft-io/src/http.rs | 129 +++++++++++++++- .../nginx-serve-static-files.conf | 2 +- tests/integration/io/conftest.py | 41 +++-- tests/integration/io/test_list_files_http.py | 146 ++++++++++++++++++ 6 files changed, 306 insertions(+), 30 deletions(-) create mode 100644 tests/integration/io/test_list_files_http.py diff --git a/Cargo.lock b/Cargo.lock index f56c399227..5ec5a3e035 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1088,6 +1088,7 @@ dependencies = [ "openssl-sys", "pyo3", "pyo3-log", + "regex", "reqwest", "serde", "serde_json", @@ -2066,9 +2067,9 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memoffset" @@ -2834,9 +2835,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.1" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -2846,9 +2847,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.2" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83d3daa6976cffb758ec878f108ba0e062a45b2d6ca3a2cca965338855476caf" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -2857,9 +2858,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.3" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab07dc67230e4a4718e70fd5c20055a4334b121f1f9db8fe63ef39ce9b8c846" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "reqwest" diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index e18cd06c00..e36ae93a19 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -24,6 +24,7 @@ log = {workspace = true} openssl-sys = {version = "0.9.93", features = ["vendored"]} pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true, optional = true} +regex = {version = "1.9.5"} serde = {workspace = true} serde_json = {workspace = true} snafu = {workspace = true} diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index a98728d91b..5c2f3b252f 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -3,13 +3,22 @@ use std::{num::ParseIntError, ops::Range, string::FromUtf8Error, sync::Arc}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; +use lazy_static::lazy_static; +use regex::Regex; use reqwest::header::{CONTENT_LENGTH, RANGE}; use snafu::{IntoError, ResultExt, Snafu}; +use url::Position; -use crate::object_io::LSResult; +use crate::object_io::{FileMetadata, FileType, LSResult}; use super::object_io::{GetResult, ObjectSource}; +lazy_static! 
{ + // Taken from: https://stackoverflow.com/a/15926317/3821154 + static ref HTML_A_TAG_HREF_RE: Regex = + Regex::new(r#"<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P[^"']+)"#).unwrap(); +} + #[derive(Debug, Snafu)] enum Error { #[snafu(display("Unable to connect to {}: {}", path, source))] @@ -45,7 +54,15 @@ enum Error { #[snafu(display( "Unable to parse data as Utf8 while reading header for file: {path}. {source}" ))] - UnableToParseUtf8 { path: String, source: FromUtf8Error }, + UnableToParseUtf8Header { path: String, source: FromUtf8Error }, + + #[snafu(display( + "Unable to parse data as Utf8 while reading body for file: {path}. {source}" + ))] + UnableToParseUtf8Body { + path: String, + source: reqwest::Error, + }, #[snafu(display( "Unable to parse data as Integer while reading header for file: {path}. {source}" @@ -53,6 +70,64 @@ enum Error { UnableToParseInteger { path: String, source: ParseIntError }, } +/// Finds and retrieves FileMetadata from HTML text +/// +/// This function will look for `` tags and return all the links that it finds as +/// absolute URLs +fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result> { + let path_url = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; + let metas = HTML_A_TAG_HREF_RE + .captures_iter(text) + .map(|captures| { + // Parse the matched URL into an absolute URL + let matched_url = captures.name("url").unwrap().as_str(); + let absolute_path = if let Ok(parsed_matched_url) = url::Url::parse(matched_url) { + // matched_url is already an absolute path + parsed_matched_url + } else if matched_url.starts_with('/') { + // matched_url is a path relative to the origin of `path` + let base = url::Url::parse(&path_url[..Position::BeforePath]).unwrap(); + base.join(matched_url) + .with_context(|_| InvalidUrlSnafu { path: matched_url })? + } else { + // matched_url is a path relative to `path` and needs to be joined + path_url + .join(matched_url) + .with_context(|_| InvalidUrlSnafu { path: matched_url })? 
+ }; + + // Ignore any links that are not descendants of `path` to avoid cycles + let relative = path_url.make_relative(&absolute_path); + match relative { + None => { + return Ok(None); + } + Some(relative_path) + if relative_path.is_empty() || relative_path.starts_with("..") => + { + return Ok(None); + } + _ => (), + }; + + let filetype = if matched_url.ends_with('/') { + FileType::Directory + } else { + FileType::File + }; + Ok(Some(FileMetadata { + filepath: absolute_path.to_string(), + // NOTE: This is consistent with fsspec behavior, but we may choose to HEAD the files to grab Content-Length + // for populating `size` if necessary + size: None, + filetype, + })) + }) + .collect::>>()?; + + Ok(metas.into_iter().flatten().collect()) +} + pub(crate) struct HttpSource { client: reqwest::Client, } @@ -135,8 +210,9 @@ impl ObjectSource for HttpSource { let headers = response.headers(); match headers.get(CONTENT_LENGTH) { Some(v) => { - let size_bytes = String::from_utf8(v.as_bytes().to_vec()) - .with_context(|_| UnableToParseUtf8Snafu:: { path: uri.into() })?; + let size_bytes = String::from_utf8(v.as_bytes().to_vec()).with_context(|_| { + UnableToParseUtf8HeaderSnafu:: { path: uri.into() } + })?; Ok(size_bytes .parse() @@ -148,11 +224,52 @@ impl ObjectSource for HttpSource { async fn ls( &self, - _path: &str, + path: &str, _delimiter: Option<&str>, _continuation_token: Option<&str>, ) -> super::Result { - unimplemented!("http ls"); + let request = self.client.get(path); + let response = request + .send() + .await + .context(UnableToConnectSnafu:: { path: path.into() })? + .error_for_status() + .with_context(|_| UnableToOpenFileSnafu { path })?; + + // Reconstruct the actual path of the request, which may have been redirected via a 301 + // This is important because downstream URL joining logic relies on proper trailing-slashes/index.html + let path = response.url().to_string(); + let path = if path.ends_with('/') { + format!("{}/", path.trim_end_matches('/')) + } else { + path + }; + + match response.headers().get("content-type") { + // If the content-type is text/html, we treat the data on this path as a traversable "directory" + Some(header_value) if header_value.to_str().map_or(false, |v| v == "text/html") => { + let text = response + .text() + .await + .with_context(|_| UnableToParseUtf8BodySnafu { + path: path.to_string(), + })?; + let file_metadatas = _get_file_metadata_from_html(path.as_str(), text.as_str())?; + Ok(LSResult { + files: file_metadatas, + continuation_token: None, + }) + } + // All other forms of content-type is treated as a raw file + _ => Ok(LSResult { + files: vec![FileMetadata { + filepath: path.to_string(), + filetype: FileType::File, + size: response.content_length(), + }], + continuation_token: None, + }), + } } } diff --git a/tests/integration/docker-compose/nginx-serve-static-files.conf b/tests/integration/docker-compose/nginx-serve-static-files.conf index 0c097a5096..9673ecd43b 100644 --- a/tests/integration/docker-compose/nginx-serve-static-files.conf +++ b/tests/integration/docker-compose/nginx-serve-static-files.conf @@ -11,7 +11,7 @@ http { listen [::]:8080; resolver 127.0.0.11; - autoindex off; + autoindex on; server_name _; server_tokens off; diff --git a/tests/integration/io/conftest.py b/tests/integration/io/conftest.py index 13076bca96..e950cafda1 100644 --- a/tests/integration/io/conftest.py +++ b/tests/integration/io/conftest.py @@ -129,21 +129,32 @@ def mount_data_nginx(nginx_config: tuple[str, pathlib.Path], folder: pathlib.Pat """ server_url, 
static_assets_tmpdir = nginx_config - # Copy data - for root, dirs, files in os.walk(folder, topdown=False): - for file in files: - shutil.copy2(os.path.join(root, file), os.path.join(static_assets_tmpdir, file)) - for dir in dirs: - shutil.copytree(os.path.join(root, dir), os.path.join(static_assets_tmpdir, dir)) - - yield [f"{server_url}/{p.relative_to(folder)}" for p in folder.glob("**/*") if p.is_file()] - - # Delete data - for root, dirs, files in os.walk(static_assets_tmpdir, topdown=False): - for file in files: - os.remove(os.path.join(root, file)) - for dir in dirs: - os.rmdir(os.path.join(root, dir)) + # Cleanup any old stuff in mount folder + for item in os.listdir(static_assets_tmpdir): + path = static_assets_tmpdir / item + if path.is_dir(): + shutil.rmtree(path) + else: + os.remove(path) + + # Copy data to mount folder + for item in os.listdir(folder): + src = folder / item + dest = static_assets_tmpdir / item + if src.is_dir(): + shutil.copytree(str(src), str(dest)) + else: + shutil.copy2(src, dest) + + try: + yield [f"{server_url}/{p.relative_to(folder)}" for p in folder.glob("**/*") if p.is_file()] + finally: + for item in os.listdir(static_assets_tmpdir): + path = static_assets_tmpdir / item + if path.is_dir(): + shutil.rmtree(static_assets_tmpdir / item) + else: + os.remove(static_assets_tmpdir / item) ### diff --git a/tests/integration/io/test_list_files_http.py b/tests/integration/io/test_list_files_http.py new file mode 100644 index 0000000000..2471133c3f --- /dev/null +++ b/tests/integration/io/test_list_files_http.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from fsspec.implementations.http import HTTPFileSystem + +from daft.daft import io_list +from tests.integration.io.conftest import mount_data_nginx + + +def compare_http_result(daft_ls_result: list, fsspec_result: list): + daft_files = [(f["path"], f["type"].lower(), f["size"]) for f in daft_ls_result] + httpfs_files = [(f["name"], f["type"], f["size"]) for f in fsspec_result] + assert len(daft_files) == len(httpfs_files) + assert sorted(daft_files) == sorted(httpfs_files) + + +@pytest.fixture(scope="module") +def nginx_http_url(nginx_config, tmpdir_factory): + tmpdir = tmpdir_factory.mktemp("test-list-http") + data_path = Path(tmpdir) + (Path(data_path) / "file.txt").touch() + (Path(data_path) / "test_ls").mkdir() + (Path(data_path) / "test_ls" / "file.txt").touch() + (Path(data_path) / "test_ls" / "paginated-10-files").mkdir() + for i in range(10): + (Path(data_path) / "test_ls" / "paginated-10-files" / f"file.{i}.txt").touch() + + with mount_data_nginx(nginx_config, data_path): + yield nginx_config[0] + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"", + f"/", + f"/test_ls", + f"/test_ls/", + f"/test_ls//", + f"/test_ls/paginated-10-files/", + ], +) +def test_http_flat_directory_listing(path, nginx_http_url): + http_path = f"{nginx_http_url}{path}" + fs = HTTPFileSystem() + fsspec_result = fs.ls(http_path, detail=True) + daft_ls_result = io_list(http_path) + compare_http_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_gs_single_file_listing(nginx_http_url): + path = f"{nginx_http_url}/test_ls/file.txt" + daft_ls_result = io_list(path) + + # NOTE: FSSpec will return size 0 list for this case, but we want to return 1 element to be + # consistent with behavior of our other file listing utilities + # fs = HTTPFileSystem() + # fsspec_result = fs.ls(path, detail=True) + + assert len(daft_ls_result) 
== 1 + assert daft_ls_result[0] == {"path": path, "size": 0, "type": "File"} + + +@pytest.mark.integration() +def test_http_notfound(nginx_http_url): + path = f"{nginx_http_url}/test_ls/MISSING" + fs = HTTPFileSystem() + with pytest.raises(FileNotFoundError, match=path): + fs.ls(path, detail=True) + + with pytest.raises(FileNotFoundError, match=path): + io_list(path) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"", + f"/", + ], +) +def test_http_flat_directory_listing_recursive(path, nginx_http_url): + http_path = f"{nginx_http_url}/{path}" + fs = HTTPFileSystem() + fsspec_result = list(fs.glob(http_path.rstrip("/") + "/**", detail=True).values()) + daft_ls_result = io_list(http_path, recursive=True) + compare_http_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_http_listing_absolute_urls(nginx_config, tmpdir): + nginx_http_url, _ = nginx_config + + tmpdir = Path(tmpdir) + test_manifest_file = tmpdir / "index.html" + test_manifest_file.write_text( + f""" + this is an absolute path to a file + this is an absolute path to a dir + """ + ) + + with mount_data_nginx(nginx_config, tmpdir): + http_path = f"{nginx_http_url}/index.html" + daft_ls_result = io_list(http_path, recursive=False) + + # NOTE: Cannot use fsspec here because they do not correctly find the links + # fsspec_result = fs.ls(http_path, detail=True) + # compare_http_result(daft_ls_result, fsspec_result) + + assert daft_ls_result == [ + {"type": "File", "path": f"{nginx_http_url}/other.html", "size": None}, + {"type": "Directory", "path": f"{nginx_http_url}/dir/", "size": None}, + ] + + +@pytest.mark.integration() +def test_http_listing_absolute_base_urls(nginx_config, tmpdir): + nginx_http_url, _ = nginx_config + + tmpdir = Path(tmpdir) + test_manifest_file = tmpdir / "index.html" + test_manifest_file.write_text( + f""" + this is an absolute base path to a file + this is an absolute base path to a dir + """ + ) + + with mount_data_nginx(nginx_config, tmpdir): + http_path = f"{nginx_http_url}/index.html" + daft_ls_result = io_list(http_path, recursive=False) + + # NOTE: Cannot use fsspec here because they do not correctly find the links + # fsspec_result = fs.ls(http_path, detail=True) + # compare_http_result(daft_ls_result, fsspec_result) + + assert daft_ls_result == [ + {"type": "File", "path": f"{nginx_http_url}/other.html", "size": None}, + {"type": "Directory", "path": f"{nginx_http_url}/dir/", "size": None}, + ] From 32fd4f69549e4577989abbdd8af133dc9a663cbc Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Thu, 28 Sep 2023 12:45:47 -0700 Subject: [PATCH 6/7] [FEAT] Add local native filesystem globbing. (#1449) This PR adds an `ls` implementation for the native local filesystem. 
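A small sketch of how the new local listing is exercised (not part of the patch), using the same `io_list` binding as the tests; the temporary directory is only for illustration:

```python
import pathlib
import tempfile

from daft.daft import io_list

with tempfile.TemporaryDirectory() as d:
    for name in ("a.parquet", "b.parquet"):
        (pathlib.Path(d) / name).touch()
    # Plain paths and file:// URLs are both accepted (see the tests below);
    # each entry carries "path", "type" ("File"/"Directory"), and "size".
    for entry in io_list(f"file://{d}"):
        print(entry["path"], entry["type"], entry["size"])
```
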
--- Cargo.lock | 1 + Cargo.toml | 1 + src/common/error/src/error.rs | 5 +- src/daft-io/Cargo.toml | 1 + src/daft-io/src/local.rs | 211 ++++++++++++++---- src/daft-io/src/object_io.rs | 25 ++- tests/integration/io/test_list_files_local.py | 115 ++++++++++ 7 files changed, 310 insertions(+), 49 deletions(-) create mode 100644 tests/integration/io/test_list_files_local.py diff --git a/Cargo.lock b/Cargo.lock index 5ec5a3e035..ed2640e265 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1095,6 +1095,7 @@ dependencies = [ "snafu", "tempfile", "tokio", + "tokio-stream", "url", ] diff --git a/Cargo.toml b/Cargo.toml index 3a9f5a9ba3..19ea900845 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ rand = "^0.8" serde_json = "1.0.104" snafu = "0.7.4" tokio = {version = "1.32.0", features = ["net", "time", "bytes", "process", "signal", "macros", "rt", "rt-multi-thread"]} +tokio-stream = {version = "0.1.14", features = ["fs"]} [workspace.dependencies.arrow2] git = "https://github.com/Eventual-Inc/arrow2" diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index 604b139e48..c61a491685 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -20,6 +20,7 @@ pub enum DaftError { path: String, source: GenericError, }, + InternalError(String), External(GenericError), } @@ -31,7 +32,8 @@ impl std::error::Error for DaftError { | DaftError::TypeError(_) | DaftError::ComputeError(_) | DaftError::ArrowError(_) - | DaftError::ValueError(_) => None, + | DaftError::ValueError(_) + | DaftError::InternalError(_) => None, DaftError::IoError(io_error) => Some(io_error), DaftError::FileNotFound { source, .. } | DaftError::External(source) => Some(&**source), #[cfg(feature = "python")] @@ -96,6 +98,7 @@ impl Display for DaftError { Self::ComputeError(s) => write!(f, "DaftError::ComputeError {s}"), Self::ArrowError(s) => write!(f, "DaftError::ArrowError {s}"), Self::ValueError(s) => write!(f, "DaftError::ValueError {s}"), + Self::InternalError(s) => write!(f, "DaftError::InternalError {s}"), #[cfg(feature = "python")] Self::PyO3Error(e) => write!(f, "DaftError::PyO3Error {e}"), Self::IoError(e) => write!(f, "DaftError::IoError {e}"), diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index e36ae93a19..4662f3e33d 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -29,6 +29,7 @@ serde = {workspace = true} serde_json = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} +tokio-stream = {workspace = true} url = "2.4.0" [dependencies.reqwest] diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index ccbbedc622..6cf62e9634 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -2,12 +2,16 @@ use std::io::SeekFrom; use std::ops::Range; use std::path::PathBuf; -use crate::object_io::LSResult; +use crate::object_io::{self, FileMetadata, LSResult}; use super::object_io::{GetResult, ObjectSource}; use super::Result; use async_trait::async_trait; use bytes::Bytes; +use common_error::DaftError; +use futures::stream::BoxStream; +use futures::StreamExt; +use futures::TryStreamExt; use snafu::{ResultExt, Snafu}; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncSeekExt}; @@ -33,6 +37,21 @@ enum Error { source: std::io::Error, }, + #[snafu(display("Unable to fetch file metadata for file {}: {}", path, source))] + UnableToFetchFileMetadata { + path: String, + source: std::io::Error, + }, + + #[snafu(display("Unable to get entries for directory {}: {}", path, source))] + UnableToFetchDirectoryEntries { + path: 
String, + source: std::io::Error, + }, + + #[snafu(display("Unexpected symlink when processing directory {}: {}", path, source))] + UnexpectedSymlink { path: String, source: DaftError }, + #[snafu(display("Unable to parse URL \"{}\"", url.to_string_lossy()))] InvalidUrl { url: PathBuf, source: ParseError }, @@ -44,7 +63,9 @@ impl From for super::Error { fn from(error: Error) -> Self { use Error::*; match error { - UnableToOpenFile { path, source } => { + UnableToOpenFile { path, source } + | UnableToFetchFileMetadata { path, source } + | UnableToFetchDirectoryEntries { path, source } => { use std::io::ErrorKind::*; match source.kind() { NotFound => super::Error::NotFound { @@ -84,49 +105,104 @@ pub struct LocalFile { #[async_trait] impl ObjectSource for LocalSource { async fn get(&self, uri: &str, range: Option>) -> super::Result { - const TO_STRIP: &str = "file://"; - if let Some(p) = uri.strip_prefix(TO_STRIP) { - let path = std::path::Path::new(p); + const LOCAL_PROTOCOL: &str = "file://"; + if let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) { Ok(GetResult::File(LocalFile { - path: path.to_path_buf(), + path: uri.into(), range, })) } else { - return Err(Error::InvalidFilePath { - path: uri.to_string(), - } - .into()); + Err(Error::InvalidFilePath { path: uri.into() }.into()) } } async fn get_size(&self, uri: &str) -> super::Result { - const TO_STRIP: &str = "file://"; - if let Some(p) = uri.strip_prefix(TO_STRIP) { - let path = std::path::Path::new(p); - let file = tokio::fs::File::open(path) - .await - .context(UnableToOpenFileSnafu { - path: path.to_string_lossy(), - })?; - let metadata = file.metadata().await.context(UnableToOpenFileSnafu { - path: path.to_string_lossy(), - })?; - return Ok(metadata.len() as usize); - } else { - return Err(Error::InvalidFilePath { + const LOCAL_PROTOCOL: &str = "file://"; + let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + return Err(Error::InvalidFilePath { path: uri.into() }.into()); + }; + let meta = tokio::fs::metadata(uri) + .await + .context(UnableToFetchFileMetadataSnafu { path: uri.to_string(), - } - .into()); - } + })?; + Ok(meta.len() as usize) } async fn ls( &self, - _path: &str, + path: &str, _delimiter: Option<&str>, _continuation_token: Option<&str>, ) -> super::Result { - unimplemented!("local ls"); + let s = self.iter_dir(path, None, None).await?; + let files = s.try_collect::>().await?; + Ok(LSResult { + files, + continuation_token: None, + }) + } + + async fn iter_dir( + &self, + uri: &str, + _delimiter: Option<&str>, + _limit: Option, + ) -> super::Result>> { + const LOCAL_PROTOCOL: &str = "file://"; + let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + return Err(Error::InvalidFilePath { path: uri.into() }.into()); + }; + let meta = + tokio::fs::metadata(uri) + .await + .with_context(|_| UnableToFetchFileMetadataSnafu { + path: uri.to_string(), + })?; + if meta.file_type().is_file() { + // Provided uri points to a file, so only return that file. 
+ return Ok(futures::stream::iter([Ok(FileMetadata { + filepath: format!("{}{}", LOCAL_PROTOCOL, uri), + size: Some(meta.len()), + filetype: object_io::FileType::File, + })]) + .boxed()); + } + let dir_entries = tokio::fs::read_dir(uri).await.with_context(|_| { + UnableToFetchDirectoryEntriesSnafu { + path: uri.to_string(), + } + })?; + let dir_stream = tokio_stream::wrappers::ReadDirStream::new(dir_entries); + let uri = Arc::new(uri.to_string()); + let file_meta_stream = dir_stream.then(move |entry| { + let uri = uri.clone(); + async move { + let entry = entry.with_context(|_| UnableToFetchDirectoryEntriesSnafu { + path: uri.to_string(), + })?; + let meta = tokio::fs::metadata(entry.path()).await.with_context(|_| { + UnableToFetchFileMetadataSnafu { + path: entry.path().to_string_lossy().to_string(), + } + })?; + Ok(FileMetadata { + filepath: format!( + "{}{}{}", + LOCAL_PROTOCOL, + entry.path().to_string_lossy(), + if meta.is_dir() { "/" } else { "" } + ), + size: Some(meta.len()), + filetype: meta.file_type().try_into().with_context(|_| { + UnexpectedSymlinkSnafu { + path: entry.path().to_string_lossy().to_string(), + } + })?, + }) + } + }); + Ok(file_meta_stream.boxed()) } } @@ -171,16 +247,15 @@ pub(crate) async fn collect_file(local_file: LocalFile) -> Result { #[cfg(test)] mod tests { - use std::io::Write; - use crate::object_io::ObjectSource; + use crate::object_io::{FileMetadata, FileType, ObjectSource}; use crate::Result; use crate::{HttpSource, LocalSource}; - #[tokio::test] - async fn test_full_get_from_local() -> Result<()> { - let mut file1 = tempfile::NamedTempFile::new().unwrap(); + async fn write_remote_parquet_to_local_file( + f: &mut tempfile::NamedTempFile, + ) -> Result { let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; @@ -190,15 +265,22 @@ mod tests { let all_bytes = bytes.as_ref(); let checksum = format!("{:x}", md5::compute(all_bytes)); assert_eq!(checksum, parquet_expected_md5); - file1.write_all(all_bytes).unwrap(); - file1.flush().unwrap(); + f.write_all(all_bytes).unwrap(); + f.flush().unwrap(); + Ok(bytes) + } + + #[tokio::test] + async fn test_local_full_get() -> Result<()> { + let mut file1 = tempfile::NamedTempFile::new().unwrap(); + let bytes = write_remote_parquet_to_local_file(&mut file1).await?; let parquet_file_path = format!("file://{}", file1.path().to_str().unwrap()); let client = LocalSource::get_client().await?; let try_all_bytes = client.get(&parquet_file_path, None).await?.bytes().await?; - assert_eq!(try_all_bytes.len(), all_bytes.len()); - assert_eq!(try_all_bytes.as_ref(), all_bytes); + assert_eq!(try_all_bytes.len(), bytes.len()); + assert_eq!(try_all_bytes, bytes); let first_bytes = client .get_range(&parquet_file_path, 0..10) @@ -206,7 +288,7 @@ mod tests { .bytes() .await?; assert_eq!(first_bytes.len(), 10); - assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); + assert_eq!(first_bytes.as_ref(), &bytes[..10]); let first_bytes = client .get_range(&parquet_file_path, 10..100) @@ -214,21 +296,58 @@ mod tests { .bytes() .await?; assert_eq!(first_bytes.len(), 90); - assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]); + assert_eq!(first_bytes.as_ref(), &bytes[10..100]); let last_bytes = client - .get_range( - &parquet_file_path, - (all_bytes.len() - 10)..(all_bytes.len() + 10), - ) + .get_range(&parquet_file_path, (bytes.len() - 10)..(bytes.len() + 10)) .await? 
.bytes() .await?; assert_eq!(last_bytes.len(), 10); - assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); + assert_eq!(last_bytes.as_ref(), &bytes[(bytes.len() - 10)..]); let size_from_get_size = client.get_size(parquet_file_path.as_str()).await?; - assert_eq!(size_from_get_size, all_bytes.len()); + assert_eq!(size_from_get_size, bytes.len()); + + Ok(()) + } + + #[tokio::test] + async fn test_local_full_ls() -> Result<()> { + let dir = tempfile::tempdir().unwrap(); + let mut file1 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file1).await?; + let mut file2 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file2).await?; + let mut file3 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); + write_remote_parquet_to_local_file(&mut file3).await?; + let dir_path = format!("file://{}", dir.path().to_string_lossy()); + let client = LocalSource::get_client().await?; + + let ls_result = client.ls(dir_path.as_ref(), None, None).await?; + let mut files = ls_result.files.clone(); + // Ensure stable sort ordering of file paths before comparing with expected payload. + files.sort_by(|a, b| a.filepath.cmp(&b.filepath)); + let mut expected = vec![ + FileMetadata { + filepath: format!("file://{}", file1.path().to_string_lossy()), + size: Some(file1.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + FileMetadata { + filepath: format!("file://{}", file2.path().to_string_lossy()), + size: Some(file2.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + FileMetadata { + filepath: format!("file://{}", file3.path().to_string_lossy()), + size: Some(file3.as_file().metadata().unwrap().len()), + filetype: FileType::File, + }, + ]; + expected.sort_by(|a, b| a.filepath.cmp(&b.filepath)); + assert_eq!(files, expected); + assert_eq!(ls_result.continuation_token, None); Ok(()) } diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 98bc23de31..9613d387d1 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use bytes::Bytes; +use common_error::DaftError; use futures::stream::{BoxStream, Stream}; use futures::StreamExt; use tokio::sync::mpsc::Sender; @@ -52,12 +53,32 @@ impl GetResult { } } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub enum FileType { File, Directory, } -#[derive(Debug)] + +impl TryFrom for FileType { + type Error = DaftError; + + fn try_from(value: std::fs::FileType) -> Result { + if value.is_dir() { + Ok(Self::Directory) + } else if value.is_file() { + Ok(Self::File) + } else if value.is_symlink() { + Err(DaftError::InternalError(format!("Symlinks should never be encountered when constructing FileMetadata, but got: {:?}", value))) + } else { + unreachable!( + "Can only be a directory, file, or symlink, but got: {:?}", + value + ) + } + } +} + +#[derive(Debug, Clone, PartialEq)] pub struct FileMetadata { pub filepath: String, pub size: Option, diff --git a/tests/integration/io/test_list_files_local.py b/tests/integration/io/test_list_files_local.py new file mode 100644 index 0000000000..dfd016038b --- /dev/null +++ b/tests/integration/io/test_list_files_local.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import pytest +from fsspec.implementations.local import LocalFileSystem + +from daft.daft import io_list + + +def local_recursive_list(fs, path) -> list: + all_results = [] + curr_level_result = 
fs.ls(path, detail=True) + for item in curr_level_result: + if item["type"] == "directory": + new_path = item["name"] + all_results.extend(local_recursive_list(fs, new_path)) + item["name"] += "/" + all_results.append(item) + else: + all_results.append(item) + return all_results + + +def compare_local_result(daft_ls_result: list, fs_result: list): + daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] + fs_files = [(f'file://{f["name"]}', f["type"]) for f in fs_result] + assert sorted(daft_files) == sorted(fs_files) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_flat_directory_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b", "c"] + for name in files: + p = d / name + p.touch() + d = str(d) + if include_protocol: + d = "file://" + d + daft_ls_result = io_list(d) + fs = LocalFileSystem() + fs_result = fs.ls(d, detail=True) + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_recursive_directory_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + d = str(d) + if include_protocol: + d = "file://" + d + daft_ls_result = io_list(d, recursive=True) + fs = LocalFileSystem() + fs_result = local_recursive_list(fs, d) + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +@pytest.mark.parametrize( + "recursive", + [False, True], +) +def test_single_file_directory_listing(tmp_path, include_protocol, recursive): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + p = f"{d}/c/cc/ccc" + if include_protocol: + p = "file://" + p + daft_ls_result = io_list(p, recursive=recursive) + fs_result = [{"name": f"{d}/c/cc/ccc", "type": "file"}] + assert len(daft_ls_result) == 1 + compare_local_result(daft_ls_result, fs_result) + + +@pytest.mark.integration() +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_missing_file_path(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b/bb", "c/cc/ccc"] + for name in files: + p = d + segments = name.split("/") + for intermediate_dir in segments[:-1]: + p /= intermediate_dir + p.mkdir() + p /= segments[-1] + p.touch() + p = f"{d}/c/cc/ddd" + if include_protocol: + p = "file://" + p + with pytest.raises(FileNotFoundError, match=f"File: {d}/c/cc/ddd not found"): + daft_ls_result = io_list(p, recursive=True) From 069432d155f0369d9041245a6b1663322b8acfba Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Thu, 28 Sep 2023 13:22:11 -0700 Subject: [PATCH 7/7] [FEAT] Add `.str.split()` API for splitting string columns. (#1409) This PR adds an `Expression.str.split()` API for splitting strings in string columns on a pattern. 
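A brief usage sketch based on the docstring added below (not part of the diff); `daft.from_pydict`, `col`, and `with_column` are assumed from Daft's existing DataFrame API:

```python
import daft
from daft import col

df = daft.from_pydict({"x": ["a,b,c", "d,e", None]})
# Split each string on a literal delimiter; the result is a List[Utf8] column
# and null rows stay null.
df = df.with_column("splits", col("x").str.split(","))
print(df.collect().to_pydict())
```
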
Closes #1388 --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 18 +++- daft/series.py | 6 ++ src/daft-core/src/array/ops/full.rs | 3 +- src/daft-core/src/array/ops/utf8.rs | 117 ++++++++++++++++++++++- src/daft-core/src/python/series.rs | 4 + src/daft-core/src/series/ops/utf8.rs | 9 ++ src/daft-dsl/src/functions/utf8/mod.rs | 11 +++ src/daft-dsl/src/functions/utf8/split.rs | 50 ++++++++++ src/daft-dsl/src/python.rs | 5 + tests/expressions/typing/test_str.py | 1 + tests/series/test_utf8_ops.py | 98 +++++++++++++++++++ tests/table/utf8/test_split.py | 28 ++++++ 13 files changed, 349 insertions(+), 3 deletions(-) create mode 100644 src/daft-dsl/src/functions/utf8/split.rs create mode 100644 tests/table/utf8/test_split.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 44d2c6880f..025b0b3db5 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -555,6 +555,7 @@ class PyExpr: def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ... def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... + def utf8_split(self, pattern: PyExpr) -> PyExpr: ... def utf8_length(self) -> PyExpr: ... def image_decode(self) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... @@ -617,6 +618,7 @@ class PySeries: def utf8_endswith(self, pattern: PySeries) -> PySeries: ... def utf8_startswith(self, pattern: PySeries) -> PySeries: ... def utf8_contains(self, pattern: PySeries) -> PySeries: ... + def utf8_split(self, pattern: PySeries) -> PySeries: ... def utf8_length(self) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2eaf5eb290..d73f8b4709 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -572,7 +572,7 @@ def endswith(self, suffix: str | Expression) -> Expression: suffix_expr = Expression._to_expression(suffix) return Expression._from_pyexpr(self._expr.utf8_endswith(suffix_expr._expr)) - def startswith(self, prefix: str) -> Expression: + def startswith(self, prefix: str | Expression) -> Expression: """Checks whether each string starts with the given pattern in a string column Example: @@ -587,6 +587,22 @@ def startswith(self, prefix: str) -> Expression: prefix_expr = Expression._to_expression(prefix) return Expression._from_pyexpr(self._expr.utf8_startswith(prefix_expr._expr)) + def split(self, pattern: str | Expression) -> Expression: + """Splits each string on the given pattern, into one or more strings. + + Example: + >>> col("x").str.split(",") + >>> col("x").str.split(col("pattern")) + + Args: + pattern: The pattern on which each string should be split, or a column to pick such patterns from. + + Returns: + Expression: A List[Utf8] expression containing the string splits for each string in the column. 
+        """
+        pattern_expr = Expression._to_expression(pattern)
+        return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr))
+
     def concat(self, other: str) -> Expression:
         """Concatenates two string expressions together
 
diff --git a/daft/series.py b/daft/series.py
index d81196c391..2430dbf31d 100644
--- a/daft/series.py
+++ b/daft/series.py
@@ -534,6 +534,12 @@ def contains(self, pattern: Series) -> Series:
         assert self._series is not None and pattern._series is not None
         return Series._from_pyseries(self._series.utf8_contains(pattern._series))
 
+    def split(self, pattern: Series) -> Series:
+        if not isinstance(pattern, Series):
+            raise ValueError(f"expected another Series but got {type(pattern)}")
+        assert self._series is not None and pattern._series is not None
+        return Series._from_pyseries(self._series.utf8_split(pattern._series))
+
     def concat(self, other: Series) -> Series:
         if not isinstance(other, Series):
             raise ValueError(f"expected another Series but got {type(other)}")
diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs
index f39c904709..f2f86c5298 100644
--- a/src/daft-core/src/array/ops/full.rs
+++ b/src/daft-core/src/array/ops/full.rs
@@ -129,7 +129,8 @@ impl FullNull for ListArray {
         Self::new(
             Field::new(name, dtype.clone()),
             empty_flat_child,
-            OffsetsBuffer::try_from(repeat(0).take(length).collect::<Vec<i64>>()).unwrap(),
+            OffsetsBuffer::try_from(repeat(0).take(length + 1).collect::<Vec<i64>>())
+                .unwrap(),
             Some(validity),
         )
     }
diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs
index 262559776f..7f9987078b 100644
--- a/src/daft-core/src/array/ops/utf8.rs
+++ b/src/daft-core/src/array/ops/utf8.rs
@@ -1,10 +1,66 @@
-use crate::datatypes::{BooleanArray, UInt64Array, Utf8Array};
+use crate::{
+    array::ListArray,
+    datatypes::{BooleanArray, Field, UInt64Array, Utf8Array},
+    DataType, Series,
+};
 use arrow2;
 
 use common_error::{DaftError, DaftResult};
 
 use super::{as_arrow::AsArrow, full::FullNull};
 
+fn split_array_on_patterns<'a, T, U>(
+    arr_iter: T,
+    pattern_iter: U,
+    buffer_len: usize,
+    name: &str,
+) -> DaftResult<ListArray>
+where
+    T: arrow2::trusted_len::TrustedLen + Iterator<Item = Option<&'a str>>,
+    U: Iterator<Item = Option<&'a str>>,
+{
+    // This will overallocate by pattern_len * N_i, where N_i is the number of pattern occurrences in the ith string in arr_iter.
+    let mut splits = arrow2::array::MutableUtf8Array::with_capacity(buffer_len);
+    // arr_iter implements TrustedLen, so we can always use size_hint().1 as the exact length of the iterator. The only
+    // time this would fail is if the length of the iterator exceeds usize::MAX, which should never happen for an i64
+    // offset array, since the array length can't exceed i64::MAX on 64-bit machines.
+    let arr_len = arr_iter.size_hint().1.unwrap();
+    let mut offsets = arrow2::offset::Offsets::new();
+    let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(arr_len);
+    for (val, pat) in arr_iter.zip(pattern_iter) {
+        let mut num_splits = 0i64;
+        match (val, pat) {
+            (Some(val), Some(pat)) => {
+                for split in val.split(pat) {
+                    splits.push(Some(split));
+                    num_splits += 1;
+                }
+                validity.push(true);
+            }
+            (_, _) => {
+                validity.push(false);
+            }
+        }
+        offsets.try_push(num_splits)?;
+    }
+    // Shrink splits capacity to current length, since we will have overallocated if any of the patterns actually occurred in the strings.
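+    // (E.g. for a single row "a,b,c" split on ",", buffer_len is 5 input bytes while the
+    // resulting splits "a", "b", "c" occupy only 3 bytes once the loop has run.)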
+    splits.shrink_to_fit();
+    let splits: arrow2::array::Utf8Array<i64> = splits.into();
+    let offsets: arrow2::offset::OffsetsBuffer<i64> = offsets.into();
+    let validity: Option<arrow2::bitmap::Bitmap> = match validity.unset_bits() {
+        0 => None,
+        _ => Some(validity.into()),
+    };
+    let flat_child =
+        Series::try_from(("splits", Box::new(splits) as Box<dyn arrow2::array::Array>))?;
+    Ok(ListArray::new(
+        Field::new(name, DataType::List(Box::new(DataType::Utf8))),
+        flat_child,
+        offsets,
+        validity,
+    ))
+}
+
 impl Utf8Array {
     pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
         self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.ends_with(pat))
@@ -18,6 +74,65 @@ impl Utf8Array {
         self.binary_broadcasted_compare(pattern, |data: &str, pat: &str| data.contains(pat))
     }
 
+    pub fn split(&self, pattern: &Utf8Array) -> DaftResult<ListArray> {
+        let self_arrow = self.as_arrow();
+        let pattern_arrow = pattern.as_arrow();
+        // Handle all-null cases.
+        if self_arrow
+            .validity()
+            .map_or(false, |v| v.unset_bits() == v.len())
+            || pattern_arrow
+                .validity()
+                .map_or(false, |v| v.unset_bits() == v.len())
+        {
+            return Ok(ListArray::full_null(
+                self.name(),
+                &DataType::List(Box::new(DataType::Utf8)),
+                std::cmp::max(self.len(), pattern.len()),
+            ));
+        // Handle empty cases.
+        } else if self.is_empty() || pattern.is_empty() {
+            return Ok(ListArray::empty(
+                self.name(),
+                &DataType::List(Box::new(DataType::Utf8)),
+            ));
+        }
+        let buffer_len = self_arrow.values().len();
+        match (self.len(), pattern.len()) {
+            // Matching len case:
+            (self_len, pattern_len) if self_len == pattern_len => split_array_on_patterns(
+                self_arrow.into_iter(),
+                pattern_arrow.into_iter(),
+                buffer_len,
+                self.name(),
+            ),
+            // Broadcast pattern case:
+            (self_len, 1) => {
+                let pattern_scalar_value = pattern.get(0).unwrap();
+                split_array_on_patterns(
+                    self_arrow.into_iter(),
+                    std::iter::repeat(Some(pattern_scalar_value)).take(self_len),
+                    buffer_len,
+                    self.name(),
+                )
+            }
+            // Broadcast self case:
+            (1, pattern_len) => {
+                let self_scalar_value = self.get(0).unwrap();
+                split_array_on_patterns(
+                    std::iter::repeat(Some(self_scalar_value)).take(pattern_len),
+                    pattern_arrow.into_iter(),
+                    buffer_len * pattern_len,
+                    self.name(),
+                )
+            }
+            // Mismatched len case:
+            (self_len, pattern_len) => Err(DaftError::ComputeError(format!(
+                "lhs and rhs have different length arrays: {self_len} vs {pattern_len}"
+            ))),
+        }
+    }
+
     pub fn length(&self) -> DaftResult<UInt64Array> {
         let self_arrow = self.as_arrow();
         let arrow_result = self_arrow
diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs
index a388585940..f3c8cb3fa0 100644
--- a/src/daft-core/src/python/series.rs
+++ b/src/daft-core/src/python/series.rs
@@ -247,6 +247,10 @@ impl PySeries {
         Ok(self.series.utf8_contains(&pattern.series)?.into())
     }
 
+    pub fn utf8_split(&self, pattern: &Self) -> PyResult<Self> {
+        Ok(self.series.utf8_split(&pattern.series)?.into())
+    }
+
     pub fn utf8_length(&self) -> PyResult<Self> {
         Ok(self.series.utf8_length()?.into())
     }
diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs
index bea305d9fc..fb2539b64e 100644
--- a/src/daft-core/src/series/ops/utf8.rs
+++ b/src/daft-core/src/series/ops/utf8.rs
@@ -32,6 +32,15 @@ impl Series {
         }
     }
 
+    pub fn utf8_split(&self, pattern: &Series) -> DaftResult<Series> {
+        match self.data_type() {
+            DataType::Utf8 => Ok(self.utf8()?.split(pattern.utf8()?)?.into_series()),
+            dt => Err(DaftError::TypeError(format!(
+                "Split not implemented for type {dt}"
+            ))),
+        }
+    }
+
     pub fn utf8_length(&self) -> DaftResult<Series> {
         match self.data_type() {
             DataType::Utf8 => Ok(self.utf8()?.length()?.into_series()),
diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs
index cd23c0883b..5c8901147e 100644
--- a/src/daft-dsl/src/functions/utf8/mod.rs
+++ b/src/daft-dsl/src/functions/utf8/mod.rs
@@ -1,12 +1,14 @@
 mod contains;
 mod endswith;
 mod length;
+mod split;
 mod startswith;
 
 use contains::ContainsEvaluator;
 use endswith::EndswithEvaluator;
 use length::LengthEvaluator;
 use serde::{Deserialize, Serialize};
+use split::SplitEvaluator;
 use startswith::StartswithEvaluator;
 
 use crate::Expr;
@@ -18,6 +20,7 @@ pub enum Utf8Expr {
     EndsWith,
     StartsWith,
     Contains,
+    Split,
     Length,
 }
 
@@ -29,6 +32,7 @@ impl Utf8Expr {
             EndsWith => &EndswithEvaluator {},
             StartsWith => &StartswithEvaluator {},
             Contains => &ContainsEvaluator {},
+            Split => &SplitEvaluator {},
             Length => &LengthEvaluator {},
         }
     }
@@ -55,6 +59,13 @@ pub fn contains(data: &Expr, pattern: &Expr) -> Expr {
     }
 }
 
+pub fn split(data: &Expr, pattern: &Expr) -> Expr {
+    Expr::Function {
+        func: super::FunctionExpr::Utf8(Utf8Expr::Split),
+        inputs: vec![data.clone(), pattern.clone()],
+    }
+}
+
 pub fn length(data: &Expr) -> Expr {
     Expr::Function {
         func: super::FunctionExpr::Utf8(Utf8Expr::Length),
diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-dsl/src/functions/utf8/split.rs
new file mode 100644
index 0000000000..8d2c238b70
--- /dev/null
+++ b/src/daft-dsl/src/functions/utf8/split.rs
@@ -0,0 +1,50 @@
+use crate::Expr;
+use daft_core::{
+    datatypes::{DataType, Field},
+    schema::Schema,
+    series::Series,
+};
+
+use common_error::{DaftError, DaftResult};
+
+use super::super::FunctionEvaluator;
+
+pub(super) struct SplitEvaluator {}
+
+impl FunctionEvaluator for SplitEvaluator {
+    fn fn_name(&self) -> &'static str {
+        "split"
+    }
+
+    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
+        match inputs {
+            [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) {
+                (Ok(data_field), Ok(pattern_field)) => {
+                    match (&data_field.dtype, &pattern_field.dtype) {
+                        (DataType::Utf8, DataType::Utf8) => {
+                            Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8))))
+                        }
+                        _ => Err(DaftError::TypeError(format!(
+                            "Expects inputs to split to be utf8, but received {data_field} and {pattern_field}",
+                        ))),
+                    }
+                }
+                (Err(e), _) | (_, Err(e)) => Err(e),
+            },
+            _ => Err(DaftError::SchemaMismatch(format!(
+                "Expected 2 input args, got {}",
+                inputs.len()
+            ))),
+        }
+    }
+
+    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
+        match inputs {
+            [data, pattern] => data.utf8_split(pattern),
+            _ => Err(DaftError::ValueError(format!(
+                "Expected 2 input args, got {}",
+                inputs.len()
+            ))),
+        }
+    }
+}
diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs
index c64afc4271..cb61044339 100644
--- a/src/daft-dsl/src/python.rs
+++ b/src/daft-dsl/src/python.rs
@@ -291,6 +291,11 @@ impl PyExpr {
         Ok(contains(&self.expr, &pattern.expr).into())
     }
 
+    pub fn utf8_split(&self, pattern: &Self) -> PyResult<Self> {
+        use crate::functions::utf8::split;
+        Ok(split(&self.expr, &pattern.expr).into())
+    }
+
     pub fn utf8_length(&self) -> PyResult<Self> {
         use crate::functions::utf8::length;
         Ok(length(&self.expr).into())
diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py
index 99d85ac460..99ab473190 100644
--- a/tests/expressions/typing/test_str.py
+++ b/tests/expressions/typing/test_str.py
@@ -15,6 +15,7 @@
         pytest.param(lambda data, pat: data.str.contains(pat), id="contains"),
         pytest.param(lambda data, pat: data.str.startswith(pat), id="startswith"),
         pytest.param(lambda data, pat: data.str.endswith(pat), id="endswith"),
+        pytest.param(lambda data, pat: data.str.split(pat), id="split"),
         pytest.param(lambda data, pat: data.str.concat(pat), id="concat"),
     ],
 )
diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py
index dafd294423..1019769501 100644
--- a/tests/series/test_utf8_ops.py
+++ b/tests/series/test_utf8_ops.py
@@ -102,6 +102,104 @@ def test_series_utf8_compare_invalid_inputs(funcname, bad_series) -> None:
         getattr(s.str, funcname)(bad_series)
 
 
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        # Single-character pattern.
+        (["a,b,c", "d,e", "f", "g,h"], [","], [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]]),
+        # Multi-character pattern.
+        (["abbcbbd", "bb", "bbe", "fbb"], ["bb"], [["a", "c", "d"], ["", ""], ["", "e"], ["f", ""]]),
+        # Empty pattern (character-splitting).
+        (["foo", "bar"], [""], [["", "f", "o", "o", ""], ["", "b", "a", "r", ""]]),
+    ],
+)
+def test_series_utf8_split_broadcast_pattern(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data))
+    patterns = Series.from_arrow(pa.array(patterns))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        (["a,b,c", "a:b:c", "a;b;c", "a.b.c"], [",", ":", ";", "."], [["a", "b", "c"]] * 4),
+        (["aabbccdd"] * 4, ["aa", "bb", "cc", "dd"], [["", "bbccdd"], ["aa", "ccdd"], ["aabb", "dd"], ["aabbcc", ""]]),
+    ],
+)
+def test_series_utf8_split_multi_pattern(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data))
+    patterns = Series.from_arrow(pa.array(patterns))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        (["aabbccdd"], ["aa", "bb", "cc", "dd"], [["", "bbccdd"], ["aa", "ccdd"], ["aabb", "dd"], ["aabbcc", ""]]),
+    ],
+)
+def test_series_utf8_split_broadcast_arr(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data))
+    patterns = Series.from_arrow(pa.array(patterns))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        # Mixed-in nulls.
+        (["a,b,c", None, "a;b;c", "a.b.c"], [",", ":", None, "."], [["a", "b", "c"], None, None, ["a", "b", "c"]]),
+        # All null data.
+        ([None] * 4, [","] * 4, [None] * 4),
+        # All null patterns.
+        (["foo"] * 4, [None] * 4, [None] * 4),
+        # Broadcasted null data.
+        ([None], [","] * 4, [None] * 4),
+        # Broadcasted null pattern.
+        (["foo"] * 4, [None], [None] * 4),
+    ],
+)
+def test_series_utf8_split_nulls(data, patterns, expected) -> None:
+    s = Series.from_arrow(pa.array(data, type=pa.string()))
+    patterns = Series.from_arrow(pa.array(patterns, type=pa.string()))
+    result = s.str.split(patterns)
+    assert result.to_pylist() == expected
+
+
+@pytest.mark.parametrize(
+    ["data", "patterns", "expected"],
+    [
+        # Empty data.
+        ([[], [","] * 4, []]),
+        # Empty patterns.
+ ([["foo"] * 4, [], []]), + ], +) +def test_series_utf8_split_empty_arrs(data, patterns, expected) -> None: + s = Series.from_arrow(pa.array(data, type=pa.string())) + patterns = Series.from_arrow(pa.array(patterns, type=pa.string())) + result = s.str.split(patterns) + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + "patterns", + [ + # Wrong number of elements, not broadcastable + Series.from_arrow(pa.array([",", "."], type=pa.string())), + # Bad input type + object(), + ], +) +def test_series_utf8_split_invalid_inputs(patterns) -> None: + s = Series.from_arrow(pa.array(["a,b,c", "d, e", "f"])) + with pytest.raises(ValueError): + s.str.split(patterns) + + def test_series_utf8_length() -> None: s = Series.from_arrow(pa.array(["foo", "barbaz", "quux"])) result = s.str.length() diff --git a/tests/table/utf8/test_split.py b/tests/table/utf8/test_split.py new file mode 100644 index 0000000000..0da7735b1d --- /dev/null +++ b/tests/table/utf8/test_split.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col, lit +from daft.table import Table + + +@pytest.mark.parametrize( + ["expr", "data", "expected"], + [ + (col("col").str.split(","), ["a,b,c", "d,e", "f", "g,h"], [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]]), + ( + col("col").str.split(lit(",")), + ["a,b,c", "d,e", "f", "g,h"], + [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]], + ), + ( + col("col").str.split(col("emptystrings") + lit(",")), + ["a,b,c", "d,e", "f", "g,h"], + [["a", "b", "c"], ["d", "e"], ["f"], ["g", "h"]], + ), + ], +) +def test_series_utf8_split_broadcast_pattern(expr, data, expected) -> None: + table = Table.from_pydict({"col": data, "emptystrings": ["", "", "", ""]}) + result = table.eval_expression_list([expr]) + assert result.to_pydict() == {"col": expected}
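
As a companion to the table-level cases above, a short Series-level sketch of the broadcasting semantics exercised by the new kernel (assumed to run interactively with pyarrow installed; names are illustrative):

    import pyarrow as pa

    from daft.series import Series

    s = Series.from_arrow(pa.array(["a,b,c", "a:b:c"]))

    # One pattern per row: each string is split on its own pattern.
    per_row = Series.from_arrow(pa.array([",", ":"]))
    assert s.str.split(per_row).to_pylist() == [["a", "b", "c"], ["a", "b", "c"]]

    # A length-1 pattern column is broadcast across all rows.
    broadcast = Series.from_arrow(pa.array([","]))
    assert s.str.split(broadcast).to_pylist() == [["a", "b", "c"], ["a:b:c"]]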