From b702e4de47de0fc176ecd4b0bfd8844a688de789 Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 20 Feb 2024 14:12:49 -0800 Subject: [PATCH 01/11] [CHORE] Populate previews only when show() or __repr__() is called (#1889) Closes #1858 Closes #1859 Changes: - no more preview creation in `.collect()` - added a `_populate_preview` method thats called during `repr` or `show()` - changed some tests --- daft/dataframe/dataframe.py | 70 ++++++++++++------------------- daft/runners/partitioning.py | 3 ++ daft/runners/pyrunner.py | 13 ++++++ daft/runners/ray_runner.py | 14 +++++++ tests/cookbook/test_write.py | 12 ++++++ tests/dataframe/test_decimals.py | 2 +- tests/dataframe/test_repr.py | 11 +++++ tests/dataframe/test_show.py | 17 ++++---- tests/dataframe/test_temporals.py | 2 +- 9 files changed, 91 insertions(+), 53 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a370d998ed..32f8bb726e 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -85,6 +85,7 @@ def __init__(self, builder: LogicalPlanBuilder) -> None: self.__builder = builder self._result_cache: Optional[PartitionCacheEntry] = None self._preview = DataFramePreview(preview_partition=None, dataframe_num_rows=None) + self._num_preview_rows = get_context().daft_execution_config.num_preview_rows @property def _builder(self) -> LogicalPlanBuilder: @@ -225,13 +226,33 @@ def iter_partitions(self) -> Iterator[Union[MicroPartition, "RayObjectRef"]]: for result in results_iter: yield result.partition() + def _populate_preview(self) -> None: + """Populates the preview of the DataFrame, if it is not already populated.""" + if self._result is None: + return + + preview_partition_invalid = ( + self._preview.preview_partition is None or len(self._preview.preview_partition) < self._num_preview_rows + ) + if preview_partition_invalid: + preview_parts = self._result._get_preview_vpartition(self._num_preview_rows) + preview_results = LocalPartitionSet({i: part for i, part in enumerate(preview_parts)}) + + preview_partition = preview_results._get_merged_vpartition() + self._preview = DataFramePreview( + preview_partition=preview_partition, + dataframe_num_rows=len(self), + ) + @DataframePublicAPI def __repr__(self) -> str: + self._populate_preview() display = DataFrameDisplay(self._preview, self.schema()) return display.__repr__() @DataframePublicAPI def _repr_html_(self) -> str: + self._populate_preview() display = DataFrameDisplay(self._preview, self.schema()) return display._repr_html_() @@ -305,30 +326,7 @@ def _from_tables(cls, *parts: MicroPartition) -> "DataFrame": df._result_cache = cache_entry # build preview - num_preview_rows = context.daft_execution_config.num_preview_rows - dataframe_num_rows = len(df) - if dataframe_num_rows > num_preview_rows: - need = num_preview_rows - preview_parts = [] - for part in parts: - part_len = len(part) - if part_len >= need: # if this part has enough rows, take what we need and break - preview_parts.append(part.slice(0, need)) - break - else: # otherwise, take the whole part and keep going - need -= part_len - preview_parts.append(part) - - preview_results = LocalPartitionSet({i: part for i, part in enumerate(preview_parts)}) - else: - preview_results = result_pset - - # set preview - preview_partition = preview_results._get_merged_vpartition() - df._preview = DataFramePreview( - preview_partition=preview_partition, - dataframe_num_rows=dataframe_num_rows, - ) + df._populate_preview() return df ### @@ -1129,26 +1127,10 @@ def collect(self, 
num_preview_rows: Optional[int] = 8) -> "DataFrame": assert self._result is not None dataframe_len = len(self._result) - requested_rows = dataframe_len if num_preview_rows is None else num_preview_rows - - # Build a DataFramePreview and cache it if necessary - preview_partition_invalid = ( - self._preview.preview_partition is None or len(self._preview.preview_partition) < requested_rows - ) - if preview_partition_invalid: - preview_df = self - if num_preview_rows is not None: - preview_df = preview_df.limit(num_preview_rows) - preview_df._materialize_results() - preview_results = preview_df._result - assert preview_results is not None - - preview_partition = preview_results._get_merged_vpartition() - self._preview = DataFramePreview( - preview_partition=preview_partition, - dataframe_num_rows=dataframe_len, - ) - + if num_preview_rows is not None: + self._num_preview_rows = num_preview_rows + else: + self._num_preview_rows = dataframe_len return self def _construct_show_display(self, n: int) -> "DataFrameDisplay": diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 8836a3bc5c..56fda08a3a 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -209,6 +209,9 @@ class PartitionSet(Generic[PartitionT]): def _get_merged_vpartition(self) -> MicroPartition: raise NotImplementedError() + def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: + raise NotImplementedError() + def to_pydict(self) -> dict[str, list[Any]]: """Retrieves all the data in a PartitionSet as a Python dictionary. Values are the raw data from each Block.""" merged_partition = self._get_merged_vpartition() diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 2be28e4c54..41b435ff24 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -51,6 +51,19 @@ def _get_merged_vpartition(self) -> MicroPartition: assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions) return MicroPartition.concat([part for id, part in ids_and_partitions]) + def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: + ids_and_partitions = self.items() + preview_parts = [] + for _, part in ids_and_partitions: + part_len = len(part) + if part_len >= num_rows: # if this part has enough rows, take what we need and break + preview_parts.append(part.slice(0, num_rows)) + break + else: # otherwise, take the whole part and keep going + num_rows -= part_len + preview_parts.append(part) + return preview_parts + def get_partition(self, idx: PartID) -> MicroPartition: return self._partitions[idx] diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 24d597084c..e2da71b62b 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -151,6 +151,20 @@ def _get_merged_vpartition(self) -> MicroPartition: all_partitions = ray.get([part for id, part in ids_and_partitions]) return MicroPartition.concat(all_partitions) + def _get_preview_vpartition(self, num_rows: int) -> list[MicroPartition]: + ids_and_partitions = self.items() + preview_parts = [] + for _, part in ids_and_partitions: + part = ray.get(part) + part_len = len(part) + if part_len >= num_rows: # if this part has enough rows, take what we need and break + preview_parts.append(part.slice(0, num_rows)) + break + else: # otherwise, take the whole part and keep going + num_rows -= part_len + preview_parts.append(part) + return preview_parts + def to_ray_dataset(self) -> RayDataset: if not _RAY_FROM_ARROW_REFS_AVAILABLE: raise ImportError( diff --git 
a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index 2f611c4059..bb1dba9668 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -21,6 +21,8 @@ def test_parquet_write(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(pd_df) == 1 + assert pd_df._preview.preview_partition is None + pd_df.__repr__() assert len(pd_df._preview.preview_partition) == 1 @@ -33,6 +35,8 @@ def test_parquet_write_with_partitioning(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(pd_df) == 5 + assert pd_df._preview.preview_partition is None + pd_df.__repr__() assert len(pd_df._preview.preview_partition) == 5 @@ -41,6 +45,8 @@ def test_empty_parquet_write_without_partitioning(tmp_path): df = df.where(daft.lit(False)) output_files = df.write_parquet(tmp_path) assert len(output_files) == 0 + assert output_files._preview.preview_partition is None + output_files.__repr__() assert len(output_files._preview.preview_partition) == 0 @@ -49,6 +55,8 @@ def test_empty_parquet_write_with_partitioning(tmp_path): df = df.where(daft.lit(False)) output_files = df.write_parquet(tmp_path, partition_cols=["Borough"]) assert len(output_files) == 0 + assert output_files._preview.preview_partition is None + output_files.__repr__() assert len(output_files._preview.preview_partition) == 0 @@ -69,6 +77,8 @@ def test_parquet_write_with_partitioning_readback_values(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(output_files) == 5 + assert output_files._preview.preview_partition is None + output_files.__repr__() assert len(output_files._preview.preview_partition) == 5 @@ -193,6 +203,8 @@ def test_csv_write(tmp_path): assert_df_equals(df.to_pandas(), read_back_pd_df) assert len(pd_df) == 1 + assert pd_df._preview.preview_partition is None + pd_df.__repr__() assert len(pd_df._preview.preview_partition) == 1 diff --git a/tests/dataframe/test_decimals.py b/tests/dataframe/test_decimals.py index 530146a098..2005790c9d 100644 --- a/tests/dataframe/test_decimals.py +++ b/tests/dataframe/test_decimals.py @@ -24,7 +24,7 @@ def test_decimal_parquet_roundtrip() -> None: df.write_parquet(dirname) df_readback = daft.read_parquet(dirname).collect() - assert str(df.to_pydict()["decimal128"]) == str(df_readback.to_pydict()["decimal128"]) + assert str(df.to_pydict()["decimal128"]) == str(df_readback.to_pydict()["decimal128"]) def test_arrow_decimal() -> None: diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index f84e13c0e7..636552978a 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import pytest from PIL import Image import daft @@ -86,6 +87,16 @@ def test_empty_repr(make_df): assert df._repr_html_() == "(No data to display: Dataframe has no columns)" +@pytest.mark.parametrize("num_preview_rows", [9, 10, None]) +def test_repr_with_non_default_preview_rows(make_df, num_preview_rows): + df = make_df({"A": [i for i in range(10)], "B": [i for i in range(10)]}) + df.collect(num_preview_rows=num_preview_rows) + df.__repr__() + + assert df._preview.dataframe_num_rows == 10 + assert len(df._preview.preview_partition) == (num_preview_rows if num_preview_rows is not None else 10) + + def test_empty_df_repr(make_df): df = make_df({"A": [1, 2, 3], "B": ["a", "b", "c"]}) df = df.where(df["A"] > 10) diff --git a/tests/dataframe/test_show.py b/tests/dataframe/test_show.py index df32865551..dd5ced328f 100644 --- a/tests/dataframe/test_show.py +++ 
b/tests/dataframe/test_show.py @@ -26,13 +26,14 @@ def test_show_some(make_df, valid_data, data_source): assert df_display.num_rows == 1 -def test_show_from_cached_collect(make_df, valid_data): +def test_show_from_cached_repr(make_df, valid_data): df = make_df(valid_data) df = df.collect() + df.__repr__() collected_preview = df._preview df_display = df._construct_show_display(8) - # Check that cached preview from df.collect() was used. + # Check that cached preview from df.__repr__() was used. assert df_display.preview is collected_preview assert df_display.schema == df.schema() assert len(df_display.preview.preview_partition) == len(valid_data) @@ -40,30 +41,32 @@ def test_show_from_cached_collect(make_df, valid_data): assert df_display.num_rows == 3 -def test_show_from_cached_collect_prefix(make_df, valid_data): +def test_show_from_cached_repr_prefix(make_df, valid_data): df = make_df(valid_data) df = df.collect(3) + df.__repr__() df_display = df._construct_show_display(2) assert df_display.schema == df.schema() assert len(df_display.preview.preview_partition) == 2 - # Check that a prefix of the cached preview from df.collect() was used, so dataframe_num_rows should be set. + # Check that a prefix of the cached preview from df.__repr__() was used, so dataframe_num_rows should be set. assert df_display.preview.dataframe_num_rows == 3 assert df_display.num_rows == 2 -def test_show_not_from_cached_collect(make_df, valid_data, data_source): +def test_show_not_from_cached_repr(make_df, valid_data, data_source): df = make_df(valid_data) df = df.collect(2) + df.__repr__() collected_preview = df._preview df_display = df._construct_show_display(8) variant = data_source if variant == "parquet": - # Cached preview from df.collect() is NOT USED because data was not materialized from parquet. + # Cached preview from df.__repr__() is NOT USED because data was not materialized from parquet. assert df_display.preview != collected_preview elif variant == "arrow": - # Cached preview from df.collect() is USED because data was materialized from arrow. + # Cached preview from df.__repr__() is USED because data was materialized from arrow. 
assert df_display.preview == collected_preview assert df_display.schema == df.schema() assert len(df_display.preview.preview_partition) == len(valid_data) diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 1c0cbcda9c..3973d3c7ce 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -90,7 +90,7 @@ def test_temporal_file_roundtrip(format, use_native_downloader) -> None: df.write_parquet(dirname) df_readback = daft.read_parquet(dirname, use_native_downloader=use_native_downloader).collect() - assert df.to_pydict() == df_readback.to_pydict() + assert df.to_pydict() == df_readback.to_pydict() @pytest.mark.parametrize( From d115e190449ebeef15782afa6e8d9326a28938d3 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:26:57 -0800 Subject: [PATCH 02/11] [DOCS] Update schema hints documentation (#1935) Schema hint documentation was out of date after: #1636 This PR fixes our docs Co-authored-by: Jay Chia --- daft/io/_csv.py | 4 ++-- daft/io/_json.py | 4 ++-- daft/io/_parquet.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/daft/io/_csv.py b/daft/io/_csv.py index 9101f83872..ec3b1f1e7f 100644 --- a/daft/io/_csv.py +++ b/daft/io/_csv.py @@ -43,8 +43,8 @@ def read_csv( Args: path (str): Path to CSV (allows for wildcards) - schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option will - disable all schema inference on data being read, and throw an error if data being read is incompatible. + schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option + will override the specified columns on the inferred schema with the specified DataTypes has_headers (bool): Whether the CSV has a header or not, defaults to True delimiter (Str): Delimiter used in the CSV, defaults to "," doubled_quote (bool): Whether to support double quote escapes, defaults to True diff --git a/daft/io/_json.py b/daft/io/_json.py index 309686c61b..c052fb2c52 100644 --- a/daft/io/_json.py +++ b/daft/io/_json.py @@ -36,8 +36,8 @@ def read_json( Args: path (str): Path to JSON files (allows for wildcards) - schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option will - disable all schema inference on data being read, and throw an error if data being read is incompatible. + schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option + will override the specified columns on the inferred schema with the specified DataTypes io_config (IOConfig): Config to be used with the native downloader use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This is currently experimental. diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index 6b95723a45..a41d82ec64 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -35,8 +35,8 @@ def read_parquet( Args: path (str): Path to Parquet file (allows for wildcards) - schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option will - disable all schema inference on data being read, and throw an error if data being read is incompatible. 
+ schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option + will override the specified columns on the inferred schema with the specified DataTypes io_config (IOConfig): Config to be used with the native downloader use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing From 1e878aef5cbc79b173f1ddc73d6e14a7e7ea580e Mon Sep 17 00:00:00 2001 From: Nick Salerni Date: Tue, 20 Feb 2024 23:56:56 -0800 Subject: [PATCH 03/11] [FEAT] Add str.lower() function (#1938) * Adding the `lower` function to match https://ibis-project.org/reference/expression-strings#ibis.expr.types.strings.StringValue.lower * Added tests showing example usage Closes #1921 --- daft/daft.pyi | 2 ++ daft/expressions/expressions.py | 11 ++++++ daft/series.py | 4 +++ docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/utf8.rs | 15 +++++++- src/daft-core/src/python/series.rs | 4 +++ src/daft-core/src/series/ops/utf8.rs | 10 ++++++ src/daft-dsl/src/functions/utf8/lower.rs | 46 ++++++++++++++++++++++++ src/daft-dsl/src/functions/utf8/mod.rs | 11 ++++++ src/daft-dsl/src/python.rs | 5 +++ tests/expressions/typing/test_str.py | 10 ++++++ tests/series/test_utf8_ops.py | 30 ++++++++++++++++ tests/table/utf8/test_lower.py | 10 ++++++ 13 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 src/daft-dsl/src/functions/utf8/lower.rs create mode 100644 tests/table/utf8/test_lower.py diff --git a/daft/daft.pyi b/daft/daft.pyi index d5ab46cb89..46fc5d9448 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -909,6 +909,7 @@ class PyExpr: def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... def utf8_split(self, pattern: PyExpr) -> PyExpr: ... def utf8_length(self) -> PyExpr: ... + def utf8_lower(self) -> PyExpr: ... def image_decode(self) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... def image_resize(self, w: int, h: int) -> PyExpr: ... @@ -984,6 +985,7 @@ class PySeries: def utf8_contains(self, pattern: PySeries) -> PySeries: ... def utf8_split(self, pattern: PySeries) -> PySeries: ... def utf8_length(self) -> PySeries: ... + def utf8_lower(self) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... def dt_day(self) -> PySeries: ... 
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index a903c968db..2901ce690d 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -732,6 +732,17 @@ def length(self) -> Expression: """ return Expression._from_pyexpr(self._expr.utf8_length()) + def lower(self) -> Expression: + """Convert UTF-8 string to all lowercase + + Example: + >>> col("x").str.lower() + + Returns: + Expression: a String expression which is `self` lowercased + """ + return Expression._from_pyexpr(self._expr.utf8_lower()) + class ExpressionListNamespace(ExpressionNamespace): def join(self, delimiter: str | Expression) -> Expression: diff --git a/daft/series.py b/daft/series.py index 1f4782b4f0..a46c0fdd86 100644 --- a/daft/series.py +++ b/daft/series.py @@ -590,6 +590,10 @@ def length(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.utf8_length()) + def lower(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.utf8_lower()) + class SeriesDateNamespace(SeriesNamespace): def date(self) -> Series: diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index e977c849d8..d89f6c79ea 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -98,6 +98,7 @@ The following methods are available under the ``expr.str`` attribute. Expression.str.concat Expression.str.length Expression.str.split + Expression.str.lower .. _api-expressions-temporal: diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index 7f9987078b..efc135532c 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -3,7 +3,7 @@ use crate::{ datatypes::{BooleanArray, Field, UInt64Array, Utf8Array}, DataType, Series, }; -use arrow2; +use arrow2::{self}; use common_error::{DaftError, DaftResult}; @@ -146,6 +146,19 @@ impl Utf8Array { Ok(UInt64Array::from((self.name(), Box::new(arrow_result)))) } + pub fn lower(&self) -> DaftResult { + let self_arrow = self.as_arrow(); + let arrow_result = self_arrow + .iter() + .map(|val| { + let v = val?; + Some(v.to_lowercase()) + }) + .collect::>() + .with_validity(self_arrow.validity().cloned()); + Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) + } + fn binary_broadcasted_compare( &self, other: &Self, diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index e89b206399..dc6eb2c5b4 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -264,6 +264,10 @@ impl PySeries { Ok(self.series.utf8_length()?.into()) } + pub fn utf8_lower(&self) -> PyResult { + Ok(self.series.utf8_lower()?.into()) + } + pub fn is_nan(&self) -> PyResult { Ok(self.series.is_nan()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index fb2539b64e..951ff6b0d7 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -50,4 +50,14 @@ impl Series { ))), } } + + pub fn utf8_lower(&self) -> DaftResult { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.lower()?.into_series()), + DataType::Null => Ok(self.clone()), + dt => Err(DaftError::TypeError(format!( + "Lower not implemented for type {dt}" + ))), + } + } } diff --git a/src/daft-dsl/src/functions/utf8/lower.rs b/src/daft-dsl/src/functions/utf8/lower.rs new file mode 100644 index 0000000000..56e153e167 --- /dev/null +++ 
b/src/daft-dsl/src/functions/utf8/lower.rs @@ -0,0 +1,46 @@ +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::Expr; +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct LowerEvaluator {} + +impl FunctionEvaluator for LowerEvaluator { + fn fn_name(&self) -> &'static str { + "lower" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to lower to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [data] => data.utf8_lower(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index 5c8901147e..967def3cfa 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -1,12 +1,14 @@ mod contains; mod endswith; mod length; +mod lower; mod split; mod startswith; use contains::ContainsEvaluator; use endswith::EndswithEvaluator; use length::LengthEvaluator; +use lower::LowerEvaluator; use serde::{Deserialize, Serialize}; use split::SplitEvaluator; use startswith::StartswithEvaluator; @@ -22,6 +24,7 @@ pub enum Utf8Expr { Contains, Split, Length, + Lower, } impl Utf8Expr { @@ -34,6 +37,7 @@ impl Utf8Expr { Contains => &ContainsEvaluator {}, Split => &SplitEvaluator {}, Length => &LengthEvaluator {}, + Lower => &LowerEvaluator {}, } } } @@ -72,3 +76,10 @@ pub fn length(data: &Expr) -> Expr { inputs: vec![data.clone()], } } + +pub fn lower(data: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Lower), + inputs: vec![data.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index f135b96bb4..fdd8c09401 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -333,6 +333,11 @@ impl PyExpr { Ok(length(&self.expr).into()) } + pub fn utf8_lower(&self) -> PyResult { + use crate::functions::utf8::lower; + Ok(lower(&self.expr).into()) + } + pub fn image_decode(&self) -> PyResult { use crate::functions::image::decode; Ok(decode(&self.expr).into()) diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index 99ab473190..da31ecbfff 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -43,3 +43,13 @@ def test_str_length(): run_kernel=s.str.length, resolvable=True, ) + + +def test_str_lower(): + s = Series.from_arrow(pa.array(["Foo", "BarBaz", "QUUX"]), name="arg") + assert_typing_resolve_vs_runtime_behavior( + data=[s], + expr=col(s.name()).str.lower(), + run_kernel=s.str.lower, + resolvable=True, + ) diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index 1019769501..203f584afb 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -222,3 +222,33 @@ def test_series_utf8_length_all_null() -> None: s = Series.from_arrow(pa.array([None, None, None])) result = s.str.length() assert result.to_pylist() == [None, None, None] + + +def 
test_series_utf8_lower() -> None: + s = Series.from_arrow(pa.array(["Foo", "BarBaz", "QUUX"])) + result = s.str.lower() + assert result.to_pylist() == ["foo", "barbaz", "quux"] + + +def test_series_utf8_lower_with_nulls() -> None: + s = Series.from_arrow(pa.array(["Foo", None, "BarBaz", "QUUX"])) + result = s.str.lower() + assert result.to_pylist() == ["foo", None, "barbaz", "quux"] + + +def test_series_utf8_lower_empty() -> None: + s = Series.from_arrow(pa.array([], type=pa.string())) + result = s.str.lower() + assert result.to_pylist() == [] + + +def test_series_utf8_lower_all_null() -> None: + s = Series.from_arrow(pa.array([None, None, None])) + result = s.str.lower() + assert result.to_pylist() == [None, None, None] + + +def test_series_utf8_lower_all_numeric_strs() -> None: + s = Series.from_arrow(pa.array(["1", "2", "3"])) + result = s.str.lower() + assert result.to_pylist() == ["1", "2", "3"] diff --git a/tests/table/utf8/test_lower.py b/tests/table/utf8/test_lower.py new file mode 100644 index 0000000000..6ed7d6b050 --- /dev/null +++ b/tests/table/utf8/test_lower.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from daft.expressions import col +from daft.table import MicroPartition + + +def test_utf8_lower(): + table = MicroPartition.from_pydict({"col": ["Foo", None, "BarBaz", "QUUX"]}) + result = table.eval_expression_list([col("col").str.lower()]) + assert result.to_pydict() == {"col": ["foo", None, "barbaz", "quux"]} From 9c66a5e07d6e43202685cf2ae2540a48b165eef5 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Wed, 21 Feb 2024 10:42:43 -0800 Subject: [PATCH 04/11] [DOCS] Add documentation for using and developing Daft on Ray (#1896) --- CONTRIBUTING.md | 11 +++ docs/source/user_guide/integrations.rst | 1 + docs/source/user_guide/integrations/ray.rst | 90 +++++++++++++++++++++ 3 files changed, 102 insertions(+) create mode 100644 docs/source/user_guide/integrations/ray.rst diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5f13ac06bf..793e4c6c8c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,3 +36,14 @@ To set up your development environment: 1. `make build`: recompile your code after modifying any Rust code in `src/` 2. `make test`: run tests 3. `DAFT_RUNNER=ray make test`: set the runner to the Ray runner and run tests (DAFT_RUNNER defaults to `py`) + +### Developing with Ray + +Running a development version of Daft on a local Ray cluster is as simple as including `daft.context.set_runner_ray()` in your Python script and then building and executing it as usual. + +To use a remote Ray cluster, run the following steps on the same operating system version as your Ray nodes, in order to ensure that your binaries are executable on Ray. + +1. `mkdir wd`: this is the working directory, it will hold all the files to be submitted to Ray for a job +2. `ln -s daft wd/daft`: create a symbolic link from the Python module to the working directory +3. `make build-release`: an optimized build to ensure that the module is small enough to be successfully uploaded to Ray. Run this after modifying any Rust code in `src/` +4. `ray job submit --working-dir wd --address "http://:8265" -- python script.py`: submit `wd/script.py` to be run on Ray diff --git a/docs/source/user_guide/integrations.rst b/docs/source/user_guide/integrations.rst index 4a565cfa22..6fdd7adcfe 100644 --- a/docs/source/user_guide/integrations.rst +++ b/docs/source/user_guide/integrations.rst @@ -3,6 +3,7 @@ Integrations .. 
toctree:: + integrations/ray integrations/iceberg integrations/microsoft-azure integrations/aws diff --git a/docs/source/user_guide/integrations/ray.rst b/docs/source/user_guide/integrations/ray.rst new file mode 100644 index 0000000000..3ea94cd20a --- /dev/null +++ b/docs/source/user_guide/integrations/ray.rst @@ -0,0 +1,90 @@ +Ray +=== + +`Ray `_ is an open-source framework for distributed computing. + +Daft's native support for Ray enables you to run distributed DataFrame workloads at scale. + +Usage +----- + +You can run Daft on Ray in two ways: by using the `Ray Client `_ or by submitting a Ray job. + +Ray Client +********** +The Ray client is quick way to get started with running tasks and retrieving their results on Ray using Python. + +.. WARNING:: + To run tasks using the Ray client, the version of Daft and the minor version (eg. 3.9, 3.10) of Python must match between client and server. + +Here's an example of how you can use the Ray client with Daft: + +.. code:: python + + >>> import daft + >>> import ray + >>> + >>> # Refer to the note under "Ray Job" for details on "runtime_env" + >>> ray.init("ray://:10001", runtime_env={"pip": ["getdaft"]}) + >>> + >>> # Starts the Ray client and tells Daft to use Ray to execute queries + >>> # If ray.init() has already been called, it uses the existing client + >>> daft.context.set_runner_ray("ray://:10001") + >>> + >>> df = daft.from_pydict({ + >>> "a": [3, 2, 5, 6, 1, 4], + >>> "b": [True, False, False, True, True, False] + >>> }) + >>> df = df.where(df["b"]).sort(df["a"]) + >>> + >>> # Daft executes the query remotely and returns a preview to the client + >>> df.collect() + ╭───────┬─────────╮ + │ a ┆ b │ + │ --- ┆ --- │ + │ Int64 ┆ Boolean │ + ╞═══════╪═════════╡ + │ 1 ┆ true │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ true │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 6 ┆ true │ + ╰───────┴─────────╯ + + (Showing first 3 of 3 rows) + +Ray Job +******* +Ray jobs allow for more control and observability over using the Ray client. In addition, your entire code runs on Ray, which means it is not constrained by the compute, network, library versions, or availability of your local machine. + +.. code:: python + + # wd/job.py + + import daft + + def main(): + # call without any arguments to connect to Ray from the head node + daft.context.set_runner_ray() + + # ... Run Daft commands here ... + + if __name__ == "__main__": + main() + +To submit this script as a job, use the Ray CLI, which can be installed with `pip install "ray[default]"`. + +.. code:: sh + + ray job submit \ + --working-dir wd \ + --address "http://:8265" \ + --runtime-env-json '{"pip": ["getdaft"]}' \ + -- python job.py + +.. NOTE:: + + The runtime env parameter specifies that Daft should be installed on the Ray workers. Alternative methods of including Daft in the worker dependencies can be found `here `_. + + +For more information about Ray jobs, see `Ray docs -> Ray Jobs Overview `_. 
From 5b2fe983e404f7861495f6270e4c63c1328be662 Mon Sep 17 00:00:00 2001 From: Nick Salerni Date: Wed, 21 Feb 2024 22:34:24 -0800 Subject: [PATCH 05/11] [FEAT] Add str.upper() function (#1942) * Adding the `upper` function to match https://ibis-project.org/reference/expression-strings#ibis.expr.types.strings.StringValue.upper * Added tests showing example usage * Refactor tests for str.lower to be a single parameterized test Closes #1920 --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 11 +++++ daft/series.py | 4 ++ docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/utf8.rs | 15 +++++- src/daft-core/src/python/series.rs | 4 ++ src/daft-core/src/series/ops/utf8.rs | 10 ++++ src/daft-dsl/src/functions/utf8/mod.rs | 11 +++++ src/daft-dsl/src/functions/utf8/upper.rs | 46 ++++++++++++++++++ src/daft-dsl/src/python.rs | 5 ++ tests/expressions/typing/test_str.py | 10 ++++ tests/series/test_utf8_ops.py | 60 ++++++++++++++---------- tests/table/utf8/test_upper.py | 10 ++++ 13 files changed, 163 insertions(+), 26 deletions(-) create mode 100644 src/daft-dsl/src/functions/utf8/upper.rs create mode 100644 tests/table/utf8/test_upper.py diff --git a/daft/daft.pyi b/daft/daft.pyi index 46fc5d9448..fc2ddbd8ac 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -910,6 +910,7 @@ class PyExpr: def utf8_split(self, pattern: PyExpr) -> PyExpr: ... def utf8_length(self) -> PyExpr: ... def utf8_lower(self) -> PyExpr: ... + def utf8_upper(self) -> PyExpr: ... def image_decode(self) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... def image_resize(self, w: int, h: int) -> PyExpr: ... @@ -986,6 +987,7 @@ class PySeries: def utf8_split(self, pattern: PySeries) -> PySeries: ... def utf8_length(self) -> PySeries: ... def utf8_lower(self) -> PySeries: ... + def utf8_upper(self) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... def dt_day(self) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2901ce690d..964903a97d 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -743,6 +743,17 @@ def lower(self) -> Expression: """ return Expression._from_pyexpr(self._expr.utf8_lower()) + def upper(self) -> Expression: + """Convert UTF-8 string to all upper + + Example: + >>> col("x").str.upper() + + Returns: + Expression: a String expression which is `self` uppercased + """ + return Expression._from_pyexpr(self._expr.utf8_upper()) + class ExpressionListNamespace(ExpressionNamespace): def join(self, delimiter: str | Expression) -> Expression: diff --git a/daft/series.py b/daft/series.py index a46c0fdd86..f776dd8260 100644 --- a/daft/series.py +++ b/daft/series.py @@ -594,6 +594,10 @@ def lower(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.utf8_lower()) + def upper(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.utf8_upper()) + class SeriesDateNamespace(SeriesNamespace): def date(self) -> Series: diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index d89f6c79ea..f0d5002846 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -99,6 +99,7 @@ The following methods are available under the ``expr.str`` attribute. Expression.str.length Expression.str.split Expression.str.lower + Expression.str.upper .. 
_api-expressions-temporal: diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index efc135532c..5f6e3eb127 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -3,7 +3,7 @@ use crate::{ datatypes::{BooleanArray, Field, UInt64Array, Utf8Array}, DataType, Series, }; -use arrow2::{self}; +use arrow2; use common_error::{DaftError, DaftResult}; @@ -159,6 +159,19 @@ impl Utf8Array { Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) } + pub fn upper(&self) -> DaftResult { + let self_arrow = self.as_arrow(); + let arrow_result = self_arrow + .iter() + .map(|val| { + let v = val?; + Some(v.to_uppercase()) + }) + .collect::>() + .with_validity(self_arrow.validity().cloned()); + Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) + } + fn binary_broadcasted_compare( &self, other: &Self, diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index dc6eb2c5b4..78d4dd9997 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -268,6 +268,10 @@ impl PySeries { Ok(self.series.utf8_lower()?.into()) } + pub fn utf8_upper(&self) -> PyResult { + Ok(self.series.utf8_upper()?.into()) + } + pub fn is_nan(&self) -> PyResult { Ok(self.series.is_nan()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index 951ff6b0d7..d29edafd41 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -60,4 +60,14 @@ impl Series { ))), } } + + pub fn utf8_upper(&self) -> DaftResult { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.upper()?.into_series()), + DataType::Null => Ok(self.clone()), + dt => Err(DaftError::TypeError(format!( + "Upper not implemented for type {dt}" + ))), + } + } } diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index 967def3cfa..c4789ff212 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -4,6 +4,7 @@ mod length; mod lower; mod split; mod startswith; +mod upper; use contains::ContainsEvaluator; use endswith::EndswithEvaluator; @@ -12,6 +13,7 @@ use lower::LowerEvaluator; use serde::{Deserialize, Serialize}; use split::SplitEvaluator; use startswith::StartswithEvaluator; +use upper::UpperEvaluator; use crate::Expr; @@ -25,6 +27,7 @@ pub enum Utf8Expr { Split, Length, Lower, + Upper, } impl Utf8Expr { @@ -38,6 +41,7 @@ impl Utf8Expr { Split => &SplitEvaluator {}, Length => &LengthEvaluator {}, Lower => &LowerEvaluator {}, + Upper => &UpperEvaluator {}, } } } @@ -83,3 +87,10 @@ pub fn lower(data: &Expr) -> Expr { inputs: vec![data.clone()], } } + +pub fn upper(data: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Upper), + inputs: vec![data.clone()], + } +} diff --git a/src/daft-dsl/src/functions/utf8/upper.rs b/src/daft-dsl/src/functions/utf8/upper.rs new file mode 100644 index 0000000000..6c7967561e --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/upper.rs @@ -0,0 +1,46 @@ +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::Expr; +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct UpperEvaluator {} + +impl FunctionEvaluator for UpperEvaluator { + fn fn_name(&self) -> &'static str { + "upper" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [data] => match 
data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to upper to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [data] => data.utf8_upper(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index fdd8c09401..f61583f411 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -338,6 +338,11 @@ impl PyExpr { Ok(lower(&self.expr).into()) } + pub fn utf8_upper(&self) -> PyResult { + use crate::functions::utf8::upper; + Ok(upper(&self.expr).into()) + } + pub fn image_decode(&self) -> PyResult { use crate::functions::image::decode; Ok(decode(&self.expr).into()) diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index da31ecbfff..1018dcff4e 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -53,3 +53,13 @@ def test_str_lower(): run_kernel=s.str.lower, resolvable=True, ) + + +def test_str_upper(): + s = Series.from_arrow(pa.array(["Foo", "BarBaz", "quux"]), name="arg") + assert_typing_resolve_vs_runtime_behavior( + data=[s], + expr=col(s.name()).str.upper(), + run_kernel=s.str.lower, + resolvable=True, + ) diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index 203f584afb..e1975f0359 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -224,31 +224,41 @@ def test_series_utf8_length_all_null() -> None: assert result.to_pylist() == [None, None, None] -def test_series_utf8_lower() -> None: - s = Series.from_arrow(pa.array(["Foo", "BarBaz", "QUUX"])) - result = s.str.lower() - assert result.to_pylist() == ["foo", "barbaz", "quux"] - - -def test_series_utf8_lower_with_nulls() -> None: - s = Series.from_arrow(pa.array(["Foo", None, "BarBaz", "QUUX"])) - result = s.str.lower() - assert result.to_pylist() == ["foo", None, "barbaz", "quux"] - - -def test_series_utf8_lower_empty() -> None: - s = Series.from_arrow(pa.array([], type=pa.string())) - result = s.str.lower() - assert result.to_pylist() == [] - - -def test_series_utf8_lower_all_null() -> None: - s = Series.from_arrow(pa.array([None, None, None])) +@pytest.mark.parametrize( + ["data", "expected"], + [ + (["Foo", "BarBaz", "QUUX"], ["foo", "barbaz", "quux"]), + # With at least one null + (["Foo", None, "BarBaz", "QUUX"], ["foo", None, "barbaz", "quux"]), + # With all nulls + ([None] * 4, [None] * 4), + # With at least one numeric strings + (["Foo", "BarBaz", "QUUX", "2"], ["foo", "barbaz", "quux", "2"]), + # With all numeric strings + (["1", "2", "3"], ["1", "2", "3"]), + ], +) +def test_series_utf8_lower(data, expected) -> None: + s = Series.from_arrow(pa.array(data)) result = s.str.lower() - assert result.to_pylist() == [None, None, None] + assert result.to_pylist() == expected -def test_series_utf8_lower_all_numeric_strs() -> None: - s = Series.from_arrow(pa.array(["1", "2", "3"])) - result = s.str.lower() - assert result.to_pylist() == ["1", "2", "3"] +@pytest.mark.parametrize( + ["data", "expected"], + [ + (["Foo", "BarBaz", "quux"], ["FOO", "BARBAZ", "QUUX"]), + # With at least one null + (["Foo", 
None, "BarBaz", "quux"], ["FOO", None, "BARBAZ", "QUUX"]), + # With all nulls + ([None] * 4, [None] * 4), + # With at least one numeric strings + (["Foo", "BarBaz", "quux", "2"], ["FOO", "BARBAZ", "QUUX", "2"]), + # With all numeric strings + (["1", "2", "3"], ["1", "2", "3"]), + ], +) +def test_series_utf8_upper(data, expected) -> None: + s = Series.from_arrow(pa.array(data)) + result = s.str.upper() + assert result.to_pylist() == expected diff --git a/tests/table/utf8/test_upper.py b/tests/table/utf8/test_upper.py new file mode 100644 index 0000000000..812afdf7d3 --- /dev/null +++ b/tests/table/utf8/test_upper.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from daft.expressions import col +from daft.table import MicroPartition + + +def test_utf8_upper(): + table = MicroPartition.from_pydict({"col": ["Foo", None, "BarBaz", "quux", "1"]}) + result = table.eval_expression_list([col("col").str.upper()]) + assert result.to_pydict() == {"col": ["FOO", None, "BARBAZ", "QUUX", "1"]} From bfa7621ecabe2a322173c823e92a24cb0d40f86c Mon Sep 17 00:00:00 2001 From: Nick Salerni Date: Fri, 23 Feb 2024 11:14:46 -0800 Subject: [PATCH 06/11] [FEAT] Add str.lstrip() and str.rstrip() functions (#1944) * Added the `lstrip` function to match https://ibis-project.org/reference/expression-strings#ibis.expr.types.strings.StringValue.lstrip * Added the `rstrip` function to match https://ibis-project.org/reference/expression-strings#ibis.expr.types.strings.StringValue.rstrip * Added tests showing example usage Closes #1922 and #1923 --- daft/daft.pyi | 4 ++ daft/expressions/expressions.py | 22 +++++++++++ daft/series.py | 8 ++++ docs/source/api_docs/expressions.rst | 2 + src/daft-core/src/array/ops/utf8.rs | 26 +++++++++++++ src/daft-core/src/python/series.rs | 8 ++++ src/daft-core/src/series/ops/utf8.rs | 20 ++++++++++ src/daft-dsl/src/functions/utf8/lstrip.rs | 46 +++++++++++++++++++++++ src/daft-dsl/src/functions/utf8/mod.rs | 22 +++++++++++ src/daft-dsl/src/functions/utf8/rstrip.rs | 46 +++++++++++++++++++++++ src/daft-dsl/src/python.rs | 10 +++++ tests/expressions/typing/test_str.py | 20 ++++++++++ tests/series/test_utf8_ops.py | 32 ++++++++++++++++ tests/table/utf8/test_lstrip.py | 10 +++++ tests/table/utf8/test_rstrip.py | 10 +++++ 15 files changed, 286 insertions(+) create mode 100644 src/daft-dsl/src/functions/utf8/lstrip.rs create mode 100644 src/daft-dsl/src/functions/utf8/rstrip.rs create mode 100644 tests/table/utf8/test_lstrip.py create mode 100644 tests/table/utf8/test_rstrip.py diff --git a/daft/daft.pyi b/daft/daft.pyi index fc2ddbd8ac..bbb27200a9 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -911,6 +911,8 @@ class PyExpr: def utf8_length(self) -> PyExpr: ... def utf8_lower(self) -> PyExpr: ... def utf8_upper(self) -> PyExpr: ... + def utf8_lstrip(self) -> PyExpr: ... + def utf8_rstrip(self) -> PyExpr: ... def image_decode(self) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... def image_resize(self, w: int, h: int) -> PyExpr: ... @@ -988,6 +990,8 @@ class PySeries: def utf8_length(self) -> PySeries: ... def utf8_lower(self) -> PySeries: ... def utf8_upper(self) -> PySeries: ... + def utf8_lstrip(self) -> PySeries: ... + def utf8_rstrip(self) -> PySeries: ... def is_nan(self) -> PySeries: ... def dt_date(self) -> PySeries: ... def dt_day(self) -> PySeries: ... 
diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 964903a97d..d994a86357 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -754,6 +754,28 @@ def upper(self) -> Expression: """ return Expression._from_pyexpr(self._expr.utf8_upper()) + def lstrip(self) -> Expression: + """Strip whitespace from the left side of a UTF-8 string + + Example: + >>> col("x").str.lstrip() + + Returns: + Expression: a String expression which is `self` with leading whitespace stripped + """ + return Expression._from_pyexpr(self._expr.utf8_lstrip()) + + def rstrip(self) -> Expression: + """Strip whitespace from the right side of a UTF-8 string + + Example: + >>> col("x").str.rstrip() + + Returns: + Expression: a String expression which is `self` with trailing whitespace stripped + """ + return Expression._from_pyexpr(self._expr.utf8_rstrip()) + class ExpressionListNamespace(ExpressionNamespace): def join(self, delimiter: str | Expression) -> Expression: diff --git a/daft/series.py b/daft/series.py index f776dd8260..fe392e9bfe 100644 --- a/daft/series.py +++ b/daft/series.py @@ -598,6 +598,14 @@ def upper(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.utf8_upper()) + def lstrip(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.utf8_lstrip()) + + def rstrip(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.utf8_rstrip()) + class SeriesDateNamespace(SeriesNamespace): def date(self) -> Series: diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index f0d5002846..285fe2b1a8 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -100,6 +100,8 @@ The following methods are available under the ``expr.str`` attribute. Expression.str.split Expression.str.lower Expression.str.upper + Expression.str.lstrip + Expression.str.rstrip .. 
_api-expressions-temporal: diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index 5f6e3eb127..feb472841f 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -172,6 +172,32 @@ impl Utf8Array { Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) } + pub fn lstrip(&self) -> DaftResult { + let self_arrow = self.as_arrow(); + let arrow_result = self_arrow + .iter() + .map(|val| { + let v = val?; + Some(v.trim_start()) + }) + .collect::>() + .with_validity(self_arrow.validity().cloned()); + Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) + } + + pub fn rstrip(&self) -> DaftResult { + let self_arrow = self.as_arrow(); + let arrow_result = self_arrow + .iter() + .map(|val| { + let v = val?; + Some(v.trim_end()) + }) + .collect::>() + .with_validity(self_arrow.validity().cloned()); + Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) + } + fn binary_broadcasted_compare( &self, other: &Self, diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 78d4dd9997..f7e2899fa3 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -272,6 +272,14 @@ impl PySeries { Ok(self.series.utf8_upper()?.into()) } + pub fn utf8_lstrip(&self) -> PyResult { + Ok(self.series.utf8_lstrip()?.into()) + } + + pub fn utf8_rstrip(&self) -> PyResult { + Ok(self.series.utf8_rstrip()?.into()) + } + pub fn is_nan(&self) -> PyResult { Ok(self.series.is_nan()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index d29edafd41..3f04f8394c 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -70,4 +70,24 @@ impl Series { ))), } } + + pub fn utf8_lstrip(&self) -> DaftResult { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.lstrip()?.into_series()), + DataType::Null => Ok(self.clone()), + dt => Err(DaftError::TypeError(format!( + "Lstrip not implemented for type {dt}" + ))), + } + } + + pub fn utf8_rstrip(&self) -> DaftResult { + match self.data_type() { + DataType::Utf8 => Ok(self.utf8()?.rstrip()?.into_series()), + DataType::Null => Ok(self.clone()), + dt => Err(DaftError::TypeError(format!( + "Rstrip not implemented for type {dt}" + ))), + } + } } diff --git a/src/daft-dsl/src/functions/utf8/lstrip.rs b/src/daft-dsl/src/functions/utf8/lstrip.rs new file mode 100644 index 0000000000..7a8e7e2b57 --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/lstrip.rs @@ -0,0 +1,46 @@ +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::Expr; +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct LstripEvaluator {} + +impl FunctionEvaluator for LstripEvaluator { + fn fn_name(&self) -> &'static str { + "lstrip" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to lstrip to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [data] => data.utf8_lstrip(), + _ => 
Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index c4789ff212..e55ea5b934 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -2,6 +2,8 @@ mod contains; mod endswith; mod length; mod lower; +mod lstrip; +mod rstrip; mod split; mod startswith; mod upper; @@ -10,6 +12,8 @@ use contains::ContainsEvaluator; use endswith::EndswithEvaluator; use length::LengthEvaluator; use lower::LowerEvaluator; +use lstrip::LstripEvaluator; +use rstrip::RstripEvaluator; use serde::{Deserialize, Serialize}; use split::SplitEvaluator; use startswith::StartswithEvaluator; @@ -28,6 +32,8 @@ pub enum Utf8Expr { Length, Lower, Upper, + Lstrip, + Rstrip, } impl Utf8Expr { @@ -42,6 +48,8 @@ impl Utf8Expr { Length => &LengthEvaluator {}, Lower => &LowerEvaluator {}, Upper => &UpperEvaluator {}, + Lstrip => &LstripEvaluator {}, + Rstrip => &RstripEvaluator {}, } } } @@ -94,3 +102,17 @@ pub fn upper(data: &Expr) -> Expr { inputs: vec![data.clone()], } } + +pub fn lstrip(data: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Lstrip), + inputs: vec![data.clone()], + } +} + +pub fn rstrip(data: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Rstrip), + inputs: vec![data.clone()], + } +} diff --git a/src/daft-dsl/src/functions/utf8/rstrip.rs b/src/daft-dsl/src/functions/utf8/rstrip.rs new file mode 100644 index 0000000000..27bdb49b4f --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/rstrip.rs @@ -0,0 +1,46 @@ +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::Expr; +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct RstripEvaluator {} + +impl FunctionEvaluator for RstripEvaluator { + fn fn_name(&self) -> &'static str { + "rstrip" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to rstrip to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [data] => data.utf8_rstrip(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index f61583f411..4dc8eeb74c 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -343,6 +343,16 @@ impl PyExpr { Ok(upper(&self.expr).into()) } + pub fn utf8_lstrip(&self) -> PyResult { + use crate::functions::utf8::lstrip; + Ok(lstrip(&self.expr).into()) + } + + pub fn utf8_rstrip(&self) -> PyResult { + use crate::functions::utf8::rstrip; + Ok(rstrip(&self.expr).into()) + } + pub fn image_decode(&self) -> PyResult { use crate::functions::image::decode; Ok(decode(&self.expr).into()) diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index 1018dcff4e..a05211c168 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -63,3 +63,23 @@ def 
test_str_upper(): run_kernel=s.str.lower, resolvable=True, ) + + +def test_str_lstrip(): + s = Series.from_arrow(pa.array(["\ta\t", "\nb\n", "\vc\t", " c\t"]), name="arg") + assert_typing_resolve_vs_runtime_behavior( + data=[s], + expr=col(s.name()).str.lstrip(), + run_kernel=s.str.lstrip, + resolvable=True, + ) + + +def test_str_rstrip(): + s = Series.from_arrow(pa.array(["\ta\t", "\nb\n", "\vc\t", "\tc "]), name="arg") + assert_typing_resolve_vs_runtime_behavior( + data=[s], + expr=col(s.name()).str.rstrip(), + run_kernel=s.str.rstrip, + resolvable=True, + ) diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index e1975f0359..2f31302724 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -262,3 +262,35 @@ def test_series_utf8_upper(data, expected) -> None: s = Series.from_arrow(pa.array(data)) result = s.str.upper() assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "expected"], + [ + (["\ta\t", "\nb\n", "\vc\t", "\td ", "e"], ["a\t", "b\n", "c\t", "d ", "e"]), + # With at least one null + (["\ta\t", None, "\vc\t", "\td ", "e"], ["a\t", None, "c\t", "d ", "e"]), + # With all nulls + ([None] * 4, [None] * 4), + ], +) +def test_series_utf8_lstrip(data, expected) -> None: + s = Series.from_arrow(pa.array(data)) + result = s.str.lstrip() + assert result.to_pylist() == expected + + +@pytest.mark.parametrize( + ["data", "expected"], + [ + (["\ta\t", "\nb\n", "\vc\t", "\td ", "e"], ["\ta", "\nb", "\vc", "\td", "e"]), + # With at least one null + (["\ta\t", None, "\vc\t", "\td ", "e"], ["\ta", None, "\vc", "\td", "e"]), + # With all nulls + ([None] * 4, [None] * 4), + ], +) +def test_series_utf8_rstrip(data, expected) -> None: + s = Series.from_arrow(pa.array(data)) + result = s.str.rstrip() + assert result.to_pylist() == expected diff --git a/tests/table/utf8/test_lstrip.py b/tests/table/utf8/test_lstrip.py new file mode 100644 index 0000000000..e515131fa1 --- /dev/null +++ b/tests/table/utf8/test_lstrip.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from daft.expressions import col +from daft.table import MicroPartition + + +def test_utf8_lstrip(): + table = MicroPartition.from_pydict({"col": ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "e"]}) + result = table.eval_expression_list([col("col").str.lstrip()]) + assert result.to_pydict() == {"col": ["a\t", None, "b\n", "c\t", "d ", "e"]} diff --git a/tests/table/utf8/test_rstrip.py b/tests/table/utf8/test_rstrip.py new file mode 100644 index 0000000000..db7b21218a --- /dev/null +++ b/tests/table/utf8/test_rstrip.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from daft.expressions import col +from daft.table import MicroPartition + + +def test_utf8_rstrip(): + table = MicroPartition.from_pydict({"col": ["\ta\t", None, "\nb\n", "\vc\t", "\td ", "e"]}) + result = table.eval_expression_list([col("col").str.rstrip()]) + assert result.to_pydict() == {"col": ["\ta", None, "\nb", "\vc", "\td", "e"]} From 99db7772d5a0e8109eb2938cb0c0e770531ab47e Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 23 Feb 2024 13:55:24 -0800 Subject: [PATCH 07/11] [FEAT] Time Array (#1892) Closes #1846 Adds a Time Array logical type - Implemented with i64 physical type - Supports microseconds and nanoseconds Screenshot 2024-02-16 at 6 05 10 PM Todo in separate PR: - temporal expressions to extract minutes, seconds, milliseconds, etc --- daft/daft.pyi | 3 + daft/datatype.py | 10 ++++ daft/expressions/expressions.py | 9 ++- .../src/array/growable/logical_growable.rs | 3 +- 
src/daft-core/src/array/growable/mod.rs | 3 +- src/daft-core/src/array/ops/as_arrow.rs | 3 +- src/daft-core/src/array/ops/cast.rs | 40 +++++++++++-- src/daft-core/src/array/ops/date.rs | 59 ++++++++++++++++++- src/daft-core/src/array/ops/get.rs | 5 +- src/daft-core/src/array/ops/hash.rs | 11 +++- src/daft-core/src/array/ops/repr.rs | 20 ++++++- src/daft-core/src/array/ops/sort.rs | 9 ++- src/daft-core/src/array/ops/take.rs | 3 +- src/daft-core/src/datatypes/dtype.rs | 1 + src/daft-core/src/datatypes/logical.rs | 5 +- src/daft-core/src/datatypes/matching.rs | 3 +- src/daft-core/src/datatypes/mod.rs | 2 +- src/daft-core/src/python/datatype.rs | 14 +++++ .../src/series/array_impl/binary_ops.rs | 3 +- .../src/series/array_impl/logical_array.rs | 3 +- src/daft-core/src/series/ops/downcast.rs | 8 ++- src/daft-core/src/series/ops/hash.rs | 1 + src/daft-core/src/series/serdes.rs | 11 +++- src/daft-core/src/utils/display_table.rs | 17 ++++++ src/daft-dsl/src/lib.rs | 1 + src/daft-dsl/src/lit.rs | 14 +++++ src/daft-dsl/src/python.rs | 6 ++ tests/expressions/test_expressions.py | 3 +- tests/expressions/typing/conftest.py | 2 + tests/io/test_csv_roundtrip.py | 12 ++++ tests/io/test_parquet_roundtrip.py | 10 ++++ tests/series/test_cast.py | 18 ++++++ tests/series/test_hash.py | 17 +++++- tests/series/test_size_bytes.py | 19 ++++++ tests/series/test_sort.py | 41 +++++++++++++ tests/series/test_take.py | 17 ++++++ tests/table/test_from_py.py | 7 +++ 37 files changed, 384 insertions(+), 29 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index bbb27200a9..a87403a3b2 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -801,6 +801,8 @@ class PyDataType: @staticmethod def date() -> PyDataType: ... @staticmethod + def time(time_unit: PyTimeUnit) -> PyDataType: ... + @staticmethod def timestamp(time_unit: PyTimeUnit, timezone: str | None = None) -> PyDataType: ... @staticmethod def duration(time_unit: PyTimeUnit) -> PyDataType: ... @@ -935,6 +937,7 @@ def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ... def col(name: str) -> PyExpr: ... def lit(item: Any) -> PyExpr: ... def date_lit(item: int) -> PyExpr: ... +def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... def series_lit(item: PySeries) -> PyExpr: ... def udf(func: Callable, expressions: list[PyExpr], return_dtype: PyDataType) -> PyExpr: ... diff --git a/daft/datatype.py b/daft/datatype.py index d4e1456399..06b5443a89 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -193,6 +193,13 @@ def date(cls) -> DataType: """Create a Date DataType: A date with a year, month and day""" return cls._from_pydatatype(PyDataType.date()) + @classmethod + def time(cls, timeunit: TimeUnit | str) -> DataType: + """Time DataType. 
Supported timeunits are "us", "ns".""" + if isinstance(timeunit, str): + timeunit = TimeUnit.from_str(timeunit) + return cls._from_pydatatype(PyDataType.time(timeunit._timeunit)) + @classmethod def timestamp(cls, timeunit: TimeUnit | str, timezone: str | None = None) -> DataType: """Timestamp DataType.""" @@ -359,6 +366,9 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: return cls.decimal128(arrow_type.precision, arrow_type.scale) elif pa.types.is_date32(arrow_type): return cls.date() + elif pa.types.is_time64(arrow_type): + timeunit = TimeUnit.from_str(pa.type_for_alias(str(arrow_type)).unit) + return cls.time(timeunit) elif pa.types.is_timestamp(arrow_type): timeunit = TimeUnit.from_str(arrow_type.unit) return cls.timestamp(timeunit=timeunit, timezone=arrow_type.tz) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index d994a86357..b2ca62815a 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -3,7 +3,7 @@ import builtins import os import sys -from datetime import date, datetime +from datetime import date, datetime, time from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, overload import pyarrow as pa @@ -15,6 +15,7 @@ from daft.daft import date_lit as _date_lit from daft.daft import lit as _lit from daft.daft import series_lit as _series_lit +from daft.daft import time_lit as _time_lit from daft.daft import timestamp_lit as _timestamp_lit from daft.daft import udf as _udf from daft.datatype import DataType, TimeUnit @@ -80,6 +81,12 @@ def lit(value: object) -> Expression: # pyo3 date (PyDate) is not available when running in abi3 mode, workaround epoch_time = value - date(1970, 1, 1) lit_value = _date_lit(epoch_time.days) + elif isinstance(value, time): + # pyo3 time (PyTime) is not available when running in abi3 mode, workaround + pa_time = pa.scalar(value) + i64_value = pa_time.cast(pa.int64()).as_py() + time_unit = TimeUnit.from_str(pa.type_for_alias(str(pa_time.type)).unit)._timeunit + lit_value = _time_lit(i64_value, time_unit) elif isinstance(value, Series): lit_value = _series_lit(value._series) else: diff --git a/src/daft-core/src/array/growable/logical_growable.rs b/src/daft-core/src/array/growable/logical_growable.rs index c62345851b..5f0770789b 100644 --- a/src/daft-core/src/array/growable/logical_growable.rs +++ b/src/daft-core/src/array/growable/logical_growable.rs @@ -6,7 +6,7 @@ use crate::{ datatypes::{ logical::LogicalArray, DaftDataType, DaftLogicalType, DateType, Decimal128Type, DurationType, EmbeddingType, Field, FixedShapeImageType, FixedShapeTensorType, ImageType, - TensorType, TimestampType, + TensorType, TimeType, TimestampType, }, DataType, IntoSeries, Series, }; @@ -77,6 +77,7 @@ macro_rules! 
impl_logical_growable { impl_logical_growable!(LogicalTimestampGrowable, TimestampType); impl_logical_growable!(LogicalDurationGrowable, DurationType); impl_logical_growable!(LogicalDateGrowable, DateType); +impl_logical_growable!(LogicalTimeGrowable, TimeType); impl_logical_growable!(LogicalEmbeddingGrowable, EmbeddingType); impl_logical_growable!(LogicalFixedShapeImageGrowable, FixedShapeImageType); impl_logical_growable!(LogicalFixedShapeTensorGrowable, FixedShapeTensorType); diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index f745809e42..b0de488352 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -5,7 +5,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, ExtensionArray, Float32Array, Float64Array, Int128Array, Int16Array, Int32Array, Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, @@ -192,6 +192,7 @@ impl_growable_array!( ); impl_growable_array!(DurationArray, logical_growable::LogicalDurationGrowable<'a>); impl_growable_array!(DateArray, logical_growable::LogicalDateGrowable<'a>); +impl_growable_array!(TimeArray, logical_growable::LogicalTimeGrowable<'a>); impl_growable_array!( EmbeddingArray, logical_growable::LogicalEmbeddingGrowable<'a> diff --git a/src/daft-core/src/array/ops/as_arrow.rs b/src/daft-core/src/array/ops/as_arrow.rs index 62950bb5e4..f9d2ef566b 100644 --- a/src/daft-core/src/array/ops/as_arrow.rs +++ b/src/daft-core/src/array/ops/as_arrow.rs @@ -4,7 +4,7 @@ use arrow2::array; use crate::{ array::DataArray, datatypes::{ - logical::{DateArray, Decimal128Array, DurationArray, TimestampArray}, + logical::{DateArray, Decimal128Array, DurationArray, TimeArray, TimestampArray}, BinaryArray, BooleanArray, DaftNumericType, NullArray, Utf8Array, }, }; @@ -65,5 +65,6 @@ impl_asarrow_dataarray!(PythonArray, PseudoArrowArray); impl_asarrow_logicalarray!(Decimal128Array, array::PrimitiveArray); impl_asarrow_logicalarray!(DateArray, array::PrimitiveArray); +impl_asarrow_logicalarray!(TimeArray, array::PrimitiveArray); impl_asarrow_logicalarray!(DurationArray, array::PrimitiveArray); impl_asarrow_logicalarray!(TimestampArray, array::PrimitiveArray); diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index 482f2ab168..417cf77874 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -4,20 +4,20 @@ use super::as_arrow::AsArrow; use crate::{ array::{ growable::make_growable, - ops::image::ImageArraySidecarData, - ops::{from_arrow::FromArrow, full::FullNull}, + ops::{from_arrow::FromArrow, full::FullNull, image::ImageArraySidecarData}, DataArray, FixedSizeListArray, ListArray, StructArray, }, datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray, - TimestampArray, + TimeArray, TimestampArray, }, DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, Int64Array, TimeUnit, UInt64Array, Utf8Array, }, series::{IntoSeries, Series}, + utils::display_table::display_time64, with_match_daft_logical_primitive_types, }; use common_error::{DaftError, DaftResult}; @@ -72,7 +72,7 @@ where use DataType::*; let source_arrow_array = match 
source_dtype { // Wrapped primitives - Decimal128(..) | Date | Timestamp(..) | Duration(..) => { + Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => { with_match_daft_logical_primitive_types!(source_dtype, |$T| { use arrow2::array::Array; to_cast @@ -115,7 +115,7 @@ where let target_physical_type = dtype.to_physical().to_arrow()?; match dtype { // Primitive wrapper types: change the arrow2 array's type field to primitive - Decimal128(..) | Date | Timestamp(..) | Duration(..) => { + Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => { with_match_daft_logical_primitive_types!(dtype, |$P| { use arrow2::array::Array; result_arrow_array @@ -363,6 +363,7 @@ impl TimestampArray { match dtype { DataType::Timestamp(..) => arrow_logical_cast(self, dtype), DataType::Date => Ok(self.date()?.into_series()), + DataType::Time(tu) => Ok(self.time(tu)?.into_series()), DataType::Utf8 => { let DataType::Timestamp(unit, timezone) = self.data_type() else { panic!("Wrong dtype for TimestampArray: {}", self.data_type()) @@ -407,6 +408,35 @@ impl TimestampArray { } } +impl TimeArray { + pub fn cast(&self, dtype: &DataType) -> DaftResult { + match dtype { + DataType::Time(..) => arrow_logical_cast(self, dtype), + DataType::Utf8 => { + let time_array = self.as_arrow(); + let time_str: arrow2::array::Utf8Array = time_array + .iter() + .map(|val| { + val.map(|val| { + let DataType::Time(unit) = &self.field.dtype else { + panic!("Wrong dtype for TimeArray: {}", self.field.dtype) + }; + display_time64(*val, unit) + }) + }) + .collect(); + Ok(Utf8Array::from((self.name(), Box::new(time_str))).into_series()) + } + DataType::Int64 => Ok(self.physical.clone().into_series()), + DataType::Float32 => self.cast(&DataType::Int64)?.cast(&DataType::Float32), + DataType::Float64 => self.cast(&DataType::Int64)?.cast(&DataType::Float64), + #[cfg(feature = "python")] + DataType::Python => cast_logical_to_python_array(self, dtype), + _ => arrow_cast(&self.physical, dtype), + } + } +} + impl DurationArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match dtype { diff --git a/src/daft-core/src/array/ops/date.rs b/src/daft-core/src/array/ops/date.rs index af1e896049..8bba15ee91 100644 --- a/src/daft-core/src/array/ops/date.rs +++ b/src/daft-core/src/array/ops/date.rs @@ -1,12 +1,12 @@ use crate::{ datatypes::{ - logical::{DateArray, TimestampArray}, - Field, Int32Array, UInt32Array, + logical::{DateArray, TimeArray, TimestampArray}, + Field, Int32Array, Int64Array, TimeUnit, UInt32Array, }, DataType, }; use arrow2::compute::arithmetics::ArraySub; -use chrono::{NaiveDate, Timelike}; +use chrono::{NaiveDate, NaiveTime, Timelike}; use common_error::{DaftError, DaftResult}; use super::as_arrow::AsArrow; @@ -108,6 +108,59 @@ impl TimestampArray { )) } + pub fn time(&self, timeunit_for_cast: &TimeUnit) -> DaftResult { + let physical = self.physical.as_arrow(); + let DataType::Timestamp(timeunit, tz) = self.data_type() else { + unreachable!("Timestamp array must have Timestamp datatype") + }; + let tu = timeunit.to_arrow(); + if !matches!( + timeunit_for_cast, + TimeUnit::Microseconds | TimeUnit::Nanoseconds + ) { + return Err(DaftError::ValueError(format!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"))); + } + let time_arrow = match tz { + Some(tz) => match arrow2::temporal_conversions::parse_offset(tz) { + Ok(tz) => Ok(arrow2::array::PrimitiveArray::::from_iter( + physical.iter().map(|ts| { + ts.map(|ts| { + let dt = + 
arrow2::temporal_conversions::timestamp_to_datetime(*ts, tu, &tz); + let time_delta = dt.time() - NaiveTime::from_hms_opt(0,0,0).unwrap(); + match timeunit_for_cast { + TimeUnit::Microseconds => time_delta.num_microseconds().unwrap(), + TimeUnit::Nanoseconds => time_delta.num_nanoseconds().unwrap(), + _ => unreachable!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"), + } + }) + }), + )), + Err(e) => Err(DaftError::TypeError(format!( + "Cannot parse timezone in Timestamp datatype: {}, error: {}", + tz, e + ))), + }, + None => Ok(arrow2::array::PrimitiveArray::::from_iter( + physical.iter().map(|ts| { + ts.map(|ts| { + let dt = arrow2::temporal_conversions::timestamp_to_naive_datetime(*ts, tu); + let time_delta = dt.time() - NaiveTime::from_hms_opt(0,0,0).unwrap(); + match timeunit_for_cast { + TimeUnit::Microseconds => time_delta.num_microseconds().unwrap(), + TimeUnit::Nanoseconds => time_delta.num_nanoseconds().unwrap(), + _ => unreachable!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"), + } + }) + }), + )), + }?; + Ok(TimeArray::new( + Field::new(self.name(), DataType::Time(*timeunit_for_cast)), + Int64Array::from((self.name(), Box::new(time_arrow))), + )) + } + pub fn hour(&self) -> DaftResult { let physical = self.physical.as_arrow(); let DataType::Timestamp(timeunit, tz) = self.data_type() else { diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index a0f534da07..6f972325ef 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -1,7 +1,9 @@ use crate::{ array::{DataArray, FixedSizeListArray, ListArray}, datatypes::{ - logical::{DateArray, Decimal128Array, DurationArray, LogicalArrayImpl, TimestampArray}, + logical::{ + DateArray, Decimal128Array, DurationArray, LogicalArrayImpl, TimeArray, TimestampArray, + }, BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, ExtensionArray, NullArray, Utf8Array, }, @@ -66,6 +68,7 @@ impl_array_arrow_get!(BooleanArray, bool); impl_array_arrow_get!(BinaryArray, &[u8]); impl_array_arrow_get!(Decimal128Array, i128); impl_array_arrow_get!(DateArray, i32); +impl_array_arrow_get!(TimeArray, i64); impl_array_arrow_get!(DurationArray, i64); impl_array_arrow_get!(TimestampArray, i64); diff --git a/src/daft-core/src/array/ops/hash.rs b/src/daft-core/src/array/ops/hash.rs index 370c4208e6..f1a84529f3 100644 --- a/src/daft-core/src/array/ops/hash.rs +++ b/src/daft-core/src/array/ops/hash.rs @@ -1,7 +1,7 @@ use crate::{ array::DataArray, datatypes::{ - logical::{DateArray, Decimal128Array, TimestampArray}, + logical::{DateArray, Decimal128Array, TimeArray, TimestampArray}, BinaryArray, BooleanArray, DaftNumericType, Int16Array, Int32Array, Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }, @@ -154,6 +154,15 @@ impl DateArray { } } +impl TimeArray { + pub fn murmur3_32(&self) -> DaftResult { + let us = self.cast(&crate::DataType::Time( + crate::datatypes::TimeUnit::Microseconds, + ))?; + us.time()?.physical.murmur3_32() + } +} + impl TimestampArray { pub fn murmur3_32(&self) -> DaftResult { let us = self.cast(&crate::DataType::Timestamp( diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 2e50379473..e0d8c959d6 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -5,12 +5,12 @@ use crate::{ datatypes::{ logical::{ DateArray, 
Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, ImageFormat, NullArray, UInt64Array, Utf8Array, }, - utils::display_table::{display_date32, display_timestamp}, + utils::display_table::{display_date32, display_time64, display_timestamp}, with_match_daft_types, DataType, Series, }; use common_error::DaftResult; @@ -149,6 +149,21 @@ impl DateArray { } } +impl TimeArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let res = self.get(idx).map_or_else( + || "None".to_string(), + |val| -> String { + let DataType::Time(unit) = &self.field.dtype else { + panic!("Wrong dtype for TimeArray: {}", self.field.dtype) + }; + display_time64(val, unit) + }, + ); + Ok(res) + } +} + impl TimestampArray { pub fn str_value(&self, idx: usize) -> DaftResult { let res = self.get(idx).map_or_else( @@ -328,6 +343,7 @@ impl_array_html_value!(StructArray); impl_array_html_value!(ExtensionArray); impl_array_html_value!(Decimal128Array); impl_array_html_value!(DateArray); +impl_array_html_value!(TimeArray); impl_array_html_value!(DurationArray); impl_array_html_value!(TimestampArray); impl_array_html_value!(EmbeddingArray); diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index f677037e62..f1813c7a4c 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -3,7 +3,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, Float32Array, Float64Array, NullArray, Utf8Array, @@ -607,6 +607,13 @@ impl DateArray { } } +impl TimeArray { + pub fn sort(&self, descending: bool) -> DaftResult { + let new_array = self.physical.sort(descending)?; + Ok(Self::new(self.field.clone(), new_array)) + } +} + impl DurationArray { pub fn sort(&self, descending: bool) -> DaftResult { let new_array = self.physical.sort(descending)?; diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index de1b59b48b..e393d4342e 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -6,7 +6,7 @@ use crate::{ datatypes::{ logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, }, BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, NullArray, Utf8Array, @@ -70,6 +70,7 @@ impl_dataarray_take!(NullArray); impl_dataarray_take!(ExtensionArray); impl_logicalarray_take!(Decimal128Array); impl_logicalarray_take!(DateArray); +impl_logicalarray_take!(TimeArray); impl_logicalarray_take!(DurationArray); impl_logicalarray_take!(TimestampArray); impl_logicalarray_take!(EmbeddingArray); diff --git a/src/daft-core/src/datatypes/dtype.rs b/src/daft-core/src/datatypes/dtype.rs index 37cdfd016a..65dbfe2b50 100644 --- a/src/daft-core/src/datatypes/dtype.rs +++ b/src/daft-core/src/datatypes/dtype.rs @@ -340,6 +340,7 @@ impl DataType { self, DataType::Decimal128(..) | DataType::Date + | DataType::Time(..) 
| DataType::Timestamp(..) | DataType::Duration(..) | DataType::Embedding(..) diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index 75c9d80bb2..ec00c609b7 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -9,7 +9,7 @@ use common_error::DaftResult; use super::{ DaftArrayType, DaftDataType, DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, - FixedShapeImageType, FixedShapeTensorType, FixedSizeListArray, ImageType, TensorType, + FixedShapeImageType, FixedShapeTensorType, FixedSizeListArray, ImageType, TensorType, TimeType, TimestampType, }; @@ -94,7 +94,7 @@ impl LogicalArrayImpl> { use crate::datatypes::DataType::*; match daft_type { // For wrapped primitive types, switch the datatype label on the arrow2 Array. - Decimal128(..) | Date | Timestamp(..) | Duration(..) => { + Decimal128(..) | Date | Timestamp(..) | Duration(..) | Time(..) => { with_match_daft_logical_primitive_types!(daft_type, |$P| { use arrow2::array::Array; physical_arrow_array @@ -148,6 +148,7 @@ pub type LogicalArray = LogicalArrayImpl::PhysicalType as DaftDataType>::ArrayType>; pub type Decimal128Array = LogicalArray; pub type DateArray = LogicalArray; +pub type TimeArray = LogicalArray; pub type DurationArray = LogicalArray; pub type ImageArray = LogicalArray; pub type TimestampArray = LogicalArray; diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index 68a7ba0040..42778f8e85 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -23,6 +23,7 @@ macro_rules! with_match_daft_types {( Float64 => __with_ty__! { Float64Type }, Timestamp(_, _) => __with_ty__! { TimestampType }, Date => __with_ty__! { DateType }, + Time(_) => __with_ty__! { TimeType }, Duration(_) => __with_ty__! { DurationType }, Binary => __with_ty__! { BinaryType }, Utf8 => __with_ty__! { Utf8Type }, @@ -38,7 +39,6 @@ macro_rules! with_match_daft_types {( Tensor(..) => __with_ty__! { TensorType }, FixedShapeTensor(..) => __with_ty__! { FixedShapeTensorType }, Decimal128(..) => __with_ty__! { Decimal128Type }, - Time(_) => unimplemented!("Array for Time DataType not implemented"), // Float16 => unimplemented!("Array for Float16 DataType not implemented"), Unknown => unimplemented!("Array for Unknown DataType not implemented"), @@ -218,6 +218,7 @@ macro_rules! with_match_daft_logical_primitive_types {( Decimal128(..) => __with_ty__! { i128 }, Duration(..) => __with_ty__! { i64 }, Date => __with_ty__! { i32 }, + Time(..) => __with_ty__! { i64 }, Timestamp(..) => __with_ty__! 
{ i64 }, _ => panic!("no logical -> primitive conversion available for {:?}", $key_type) } diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index efa9a34f09..1b91b499ca 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -174,7 +174,7 @@ impl_nested_datatype!(ListType, ListArray); impl_daft_logical_data_array_datatype!(Decimal128Type, Unknown, Int128Type); impl_daft_logical_data_array_datatype!(TimestampType, Unknown, Int64Type); impl_daft_logical_data_array_datatype!(DateType, Date, Int32Type); -// impl_daft_logical_data_array_datatype!(TimeType, Unknown, Int64Type); +impl_daft_logical_data_array_datatype!(TimeType, Unknown, Int64Type); impl_daft_logical_data_array_datatype!(DurationType, Unknown, Int64Type); impl_daft_logical_data_array_datatype!(ImageType, Unknown, StructType); impl_daft_logical_data_array_datatype!(TensorType, Unknown, StructType); diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 09973227d9..aacdd647ad 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -160,6 +160,20 @@ impl PyDataType { Ok(DataType::Date.into()) } + #[staticmethod] + pub fn time(timeunit: PyTimeUnit) -> PyResult { + if !matches!( + timeunit.timeunit, + TimeUnit::Microseconds | TimeUnit::Nanoseconds + ) { + return Err(PyValueError::new_err(format!( + "The time unit for time types must be microseconds or nanoseconds, but got: {}", + timeunit.timeunit + ))); + } + Ok(DataType::Time(timeunit.timeunit).into()) + } + #[staticmethod] pub fn timestamp(timeunit: PyTimeUnit, timezone: Option) -> PyResult { Ok(DataType::Timestamp(timeunit.timeunit, timezone).into()) diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index b2a652d529..5fd46e106a 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -14,7 +14,7 @@ use crate::{ use crate::datatypes::logical::{ DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, - ImageArray, TensorArray, TimestampArray, + ImageArray, TensorArray, TimeArray, TimestampArray, }; use crate::datatypes::{ BinaryArray, BooleanArray, ExtensionArray, Float32Array, Float64Array, Int16Array, Int32Array, @@ -229,6 +229,7 @@ impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} +impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper { fn add(&self, rhs: &Series) -> DaftResult { use DataType::*; diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index 51a2ecd39d..8f4e82aab9 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -1,6 +1,6 @@ use crate::datatypes::logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, LogicalArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, LogicalArray, TensorArray, TimeArray, TimestampArray, }; use crate::datatypes::{BooleanArray, DaftLogicalType, Field}; @@ -225,6 +225,7 @@ macro_rules! 
impl_series_like_for_logical_array { impl_series_like_for_logical_array!(Decimal128Array); impl_series_like_for_logical_array!(DateArray); +impl_series_like_for_logical_array!(TimeArray); impl_series_like_for_logical_array!(DurationArray); impl_series_like_for_logical_array!(TimestampArray); impl_series_like_for_logical_array!(ImageArray); diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index 8d8469fe06..a9c5780fee 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -1,5 +1,7 @@ use crate::array::{FixedSizeListArray, ListArray, StructArray}; -use crate::datatypes::logical::{DateArray, Decimal128Array, FixedShapeImageArray, TimestampArray}; +use crate::datatypes::logical::{ + DateArray, Decimal128Array, FixedShapeImageArray, TimeArray, TimestampArray, +}; use crate::datatypes::*; use crate::series::array_impl::ArrayWrapper; use crate::series::Series; @@ -99,6 +101,10 @@ impl Series { self.downcast() } + pub fn time(&self) -> DaftResult<&TimeArray> { + self.downcast() + } + pub fn timestamp(&self) -> DaftResult<&TimestampArray> { self.downcast() } diff --git a/src/daft-core/src/series/ops/hash.rs b/src/daft-core/src/series/ops/hash.rs index 80359f06e5..f5196eeae4 100644 --- a/src/daft-core/src/series/ops/hash.rs +++ b/src/daft-core/src/series/ops/hash.rs @@ -28,6 +28,7 @@ impl Series { Utf8 => self.utf8()?.murmur3_32(), Binary => self.binary()?.murmur3_32(), Date => self.date()?.murmur3_32(), + Time(..) => self.time()?.murmur3_32(), Timestamp(..) => self.timestamp()?.murmur3_32(), Decimal128(..) => self.decimal128()?.murmur3_32(), v => panic!("murmur3 hash not implemented for datatype: {v}"), diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index 8d9370a7cf..6503a6ebae 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -10,7 +10,7 @@ use crate::{ }, datatypes::logical::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeTensorArray, ImageArray, TensorArray, TimestampArray, + FixedShapeTensorArray, ImageArray, TensorArray, TimeArray, TimestampArray, }, with_match_daft_types, DataType, IntoSeries, Series, }; @@ -227,7 +227,14 @@ impl<'d> serde::Deserialize<'d> for Series { .into_series(), ) } - Time(..) => panic!("Time Deserialization not implemented"), + Time(..) => { + type PType = <::PhysicalType as DaftDataType>::ArrayType; + let physical = map.next_value::()?; + Ok( + TimeArray::new(field, physical.downcast::().unwrap().clone()) + .into_series(), + ) + } Duration(..) 
=> { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; diff --git a/src/daft-core/src/utils/display_table.rs b/src/daft-core/src/utils/display_table.rs index 5bb2a31482..ec63a0deda 100644 --- a/src/daft-core/src/utils/display_table.rs +++ b/src/daft-core/src/utils/display_table.rs @@ -15,6 +15,23 @@ pub fn display_date32(val: i32) -> String { format!("{date}") } +pub fn display_time64(val: i64, unit: &TimeUnit) -> String { + let time = match unit { + TimeUnit::Nanoseconds => chrono::NaiveTime::from_num_seconds_from_midnight_opt( + (val / 1_000_000_000) as u32, + (val % 1_000_000_000) as u32, + ) + .unwrap(), + TimeUnit::Microseconds => chrono::NaiveTime::from_num_seconds_from_midnight_opt( + (val / 1_000_000) as u32, + ((val % 1_000_000) * 1_000) as u32, + ) + .unwrap(), + _ => panic!("Unsupported time unit for time64: {unit}"), + }; + format!("{time}") +} + pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option) -> String { use crate::array::ops::cast::{ timestamp_to_str_naive, timestamp_to_str_offset, timestamp_to_str_tz, diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 78696a72e1..ea6cafc875 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -24,6 +24,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_wrapped(wrap_pyfunction!(python::col))?; parent.add_wrapped(wrap_pyfunction!(python::lit))?; parent.add_wrapped(wrap_pyfunction!(python::date_lit))?; + parent.add_wrapped(wrap_pyfunction!(python::time_lit))?; parent.add_wrapped(wrap_pyfunction!(python::timestamp_lit))?; parent.add_wrapped(wrap_pyfunction!(python::series_lit))?; parent.add_wrapped(wrap_pyfunction!(python::udf))?; diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index ec8923cff9..c0fc08ee21 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -1,5 +1,7 @@ use crate::expr::Expr; +use daft_core::datatypes::logical::TimeArray; +use daft_core::utils::display_table::display_time64; use daft_core::utils::hashable_float_wrapper::FloatWrapper; use daft_core::{array::ops::full::FullNull, datatypes::DataType}; use daft_core::{ @@ -55,6 +57,8 @@ pub enum LiteralValue { /// An [`i32`] representing the elapsed time since UNIX epoch (1970-01-01) /// in days. Date(i32), + /// An [`i64`] representing a time in microseconds or nanoseconds since midnight. + Time(i64, TimeUnit), /// A 64-bit floating point number. 
Float64(f64), /// A list @@ -81,6 +85,10 @@ impl Hash for LiteralValue { Int64(n) => n.hash(state), UInt64(n) => n.hash(state), Date(n) => n.hash(state), + Time(n, tu) => { + n.hash(state); + tu.hash(state); + } Timestamp(n, tu, tz) => { n.hash(state); tu.hash(state); @@ -115,6 +123,7 @@ impl Display for LiteralValue { Int64(val) => write!(f, "{val}"), UInt64(val) => write!(f, "{val}"), Date(val) => write!(f, "{}", display_date32(*val)), + Time(val, tu) => write!(f, "{}", display_time64(*val, tu)), Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)), Float64(val) => write!(f, "{val:.1}"), Series(series) => write!(f, "{}", display_series_literal(series)), @@ -145,6 +154,7 @@ impl LiteralValue { Int64(_) => DataType::Int64, UInt64(_) => DataType::UInt64, Date(_) => DataType::Date, + Time(_, tu) => DataType::Time(*tu), Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), Float64(_) => DataType::Float64, Series(series) => series.data_type().clone(), @@ -170,6 +180,10 @@ impl LiteralValue { let physical = Int32Array::from(("literal", [*val].as_slice())); DateArray::new(Field::new("literal", self.get_type()), physical).into_series() } + Time(val, ..) => { + let physical = Int64Array::from(("literal", [*val].as_slice())); + TimeArray::new(Field::new("literal", self.get_type()), physical).into_series() + } Timestamp(val, ..) => { let physical = Int64Array::from(("literal", [*val].as_slice())); TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series() diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 4dc8eeb74c..e5c26b7bed 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -34,6 +34,12 @@ pub fn date_lit(item: i32) -> PyResult { Ok(expr.into()) } +#[pyfunction] +pub fn time_lit(item: i64, tu: PyTimeUnit) -> PyResult { + let expr = Expr::Literal(LiteralValue::Time(item, tu.timeunit)); + Ok(expr.into()) +} + #[pyfunction] pub fn timestamp_lit(val: i64, tu: PyTimeUnit, tz: Option) -> PyResult { let expr = Expr::Literal(LiteralValue::Timestamp(val, tu.timeunit, tz)); diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 84102ef625..92c80df665 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -1,7 +1,7 @@ from __future__ import annotations import copy -from datetime import date, datetime +from datetime import date, datetime, time import pytest import pytz @@ -26,6 +26,7 @@ (None, DataType.null()), (Series.from_pylist([1, 2, 3]), DataType.int64()), (date(2023, 1, 1), DataType.date()), + (time(1, 2, 3, 4), DataType.time(timeunit=TimeUnit.from_str("us"))), (datetime(2023, 1, 1), DataType.timestamp(timeunit=TimeUnit.from_str("us"))), (datetime(2022, 1, 1, tzinfo=pytz.utc), DataType.timestamp(timeunit=TimeUnit.from_str("us"), timezone="UTC")), ], diff --git a/tests/expressions/typing/conftest.py b/tests/expressions/typing/conftest.py index 5c13c7e378..89afb366e6 100644 --- a/tests/expressions/typing/conftest.py +++ b/tests/expressions/typing/conftest.py @@ -226,6 +226,7 @@ def has_supertype(dt1: DataType, dt2: DataType) -> bool: # --- Across type hierarchies --- date_and_numeric = x == DataType.date() and is_numeric(y) + time_and_numeric = x == (DataType.time("us") or DataType.time("ns")) and is_numeric(y) timestamp_and_big_numeric = x._is_temporal_type() and is_numeric_bitwidth_gte_32(y) if ( @@ -234,6 +235,7 @@ def has_supertype(dt1: DataType, dt2: DataType) -> bool: or both_numeric or both_temporal or 
date_and_numeric + or time_and_numeric or timestamp_and_big_numeric ): return True diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py index 1508376653..2043e30a1b 100644 --- a/tests/io/test_csv_roundtrip.py +++ b/tests/io/test_csv_roundtrip.py @@ -31,6 +31,18 @@ DataType.float64(), ), ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date32(), DataType.date(), DataType.date()), + ( + [datetime.time(1, 2, 3, 4), datetime.time(5, 6, 7, 8), None], + pa.time64("us"), + DataType.time(TimeUnit.us()), + DataType.time(TimeUnit.us()), + ), + ( + [datetime.time(1, 2, 3, 4), datetime.time(5, 6, 7, 8), None], + pa.time64("ns"), + DataType.time(TimeUnit.ns()), + DataType.time(TimeUnit.us()), + ), ( [datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None], pa.timestamp("ms"), diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 46fb92a661..79050f3fcc 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -27,6 +27,16 @@ ([None, None, None], pa.null(), DataType.null()), ([decimal.Decimal("1.23"), decimal.Decimal("1.24"), None], pa.decimal128(16, 8), DataType.decimal128(16, 8)), ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date32(), DataType.date()), + ( + [datetime.time(12, 1, 22, 4), datetime.time(13, 8, 45, 34), None], + pa.time64("us"), + DataType.time(TimeUnit.us()), + ), + ( + [datetime.time(12, 1, 22, 4), datetime.time(13, 8, 45, 34), None], + pa.time64("ns"), + DataType.time(TimeUnit.ns()), + ), ( [datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None], pa.timestamp("ms"), diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index 9367ba131c..eb664e0118 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -773,3 +773,21 @@ def test_cast_date_to_timestamp(): back = casted.dt.date() assert (input == back).to_pylist() == [True] + + +@pytest.mark.parametrize("timeunit", ["us", "ns"]) +def test_cast_timestamp_to_time(timeunit): + from datetime import datetime, time + + input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)]) + casted = input.cast(DataType.time(timeunit)) + assert casted.to_pylist() == [time(12, 34, 56, 78)] + + +@pytest.mark.parametrize("timeunit", ["s", "ms"]) +def test_cast_timestamp_to_time_unsupported_timeunit(timeunit): + from datetime import datetime + + input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)]) + with pytest.raises(ValueError): + input.cast(DataType.time(timeunit)) diff --git a/tests/series/test_hash.py b/tests/series/test_hash.py index f2e3f8951a..444c69fb22 100644 --- a/tests/series/test_hash.py +++ b/tests/series/test_hash.py @@ -1,7 +1,7 @@ from __future__ import annotations import decimal -from datetime import date, datetime +from datetime import date, datetime, time import numpy as np import pytest @@ -176,6 +176,21 @@ def test_murmur3_32_hash_date(): assert hashes.to_pylist() == [-653330422, None] +def test_murmur3_32_hash_time(): + arr = Series.from_pylist([time(22, 31, 8, 0), None]) + assert arr.datatype() == DataType.time("us") + hashes = arr.murmur3_32() + assert hashes.to_pylist() == [-662762989, None] + + +def test_murmur3_32_hash_time_nanoseconds(): + arr = Series.from_pylist([time(22, 31, 8, 0), None]) + arr = arr.cast(DataType.time("ns")) + assert arr.datatype() == DataType.time("ns") + hashes = arr.murmur3_32() + assert hashes.to_pylist() == [-662762989, None] + + def test_murmur3_32_hash_timestamp(): arr = 
Series.from_pylist([datetime(2017, 11, 16, 22, 31, 8), None]) hashes = arr.murmur3_32() diff --git a/tests/series/test_size_bytes.py b/tests/series/test_size_bytes.py index 8bbd2ff883..3fe42b8662 100644 --- a/tests/series/test_size_bytes.py +++ b/tests/series/test_size_bytes.py @@ -88,6 +88,25 @@ def test_series_date_size_bytes(size, with_nulls) -> None: assert s.size_bytes() == get_total_buffer_size(data) +@pytest.mark.parametrize("size", [0, 1, 2, 8, 9, 16]) +@pytest.mark.parametrize("with_nulls", [True, False]) +@pytest.mark.parametrize("precision", ["us", "ns"]) +def test_series_time_size_bytes(size, with_nulls, precision) -> None: + from datetime import time + + pydata = [time(i, i, i, i) for i in range(size)] + + if with_nulls and size > 0: + data = pa.array(pydata[:-1] + [None], pa.time64(precision)) + else: + data = pa.array(pydata, pa.time64(precision)) + + s = Series.from_arrow(data) + + assert s.datatype() == DataType.time(precision) + assert s.size_bytes() == get_total_buffer_size(data) + + @pytest.mark.parametrize("size", [0, 1, 2, 8, 9, 16]) @pytest.mark.parametrize("with_nulls", [True, False]) def test_series_binary_size_bytes(size, with_nulls) -> None: diff --git a/tests/series/test_sort.py b/tests/series/test_sort.py index d9c67892ea..790aa4a1d1 100644 --- a/tests/series/test_sort.py +++ b/tests/series/test_sort.py @@ -112,6 +112,47 @@ def date_maker(d): assert taken.to_pylist() == sorted_order[::-1] +@pytest.mark.parametrize("timeunit", ["us", "ns"]) +def test_series_time_sorting(timeunit) -> None: + from datetime import time + + def time_maker(h, m, s, us): + if us is None: + return None + return time(h, m, s, us) + + times = list(map(time_maker, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 4, 1, None, 2, None])) + s = Series.from_pylist(times) + sorted_order = list( + map(time_maker, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 2, 4, 5, None, None]) + ) + s = s.cast(DataType.time(timeunit)) + s_sorted = s.sort() + assert len(s_sorted) == len(s) + assert s_sorted.datatype() == s.datatype() + assert s_sorted.to_pylist() == sorted_order + + s_argsorted = s.argsort() + assert len(s_argsorted) == len(s) + + taken = s.take(s_argsorted) + assert len(taken) == len(s) + assert taken.to_pylist() == sorted_order + + ## Descending + s_sorted = s.sort(descending=True) + assert len(s_sorted) == len(s) + assert s_sorted.datatype() == s.datatype() + assert s_sorted.to_pylist() == sorted_order[::-1] + + s_argsorted = s.argsort(descending=True) + assert len(s_argsorted) == len(s) + + taken = s.take(s_argsorted) + assert len(taken) == len(s) + assert taken.to_pylist() == sorted_order[::-1] + + def test_series_string_sorting() -> None: data = pa.array(["hi", "bye", "thai", None, "2", None, "h", "by"]) sorted_order = ["2", "by", "bye", "h", "hi", "thai", None, None] diff --git a/tests/series/test_take.py b/tests/series/test_take.py index b86015847f..ea68ff515e 100644 --- a/tests/series/test_take.py +++ b/tests/series/test_take.py @@ -45,6 +45,23 @@ def date_maker(d): assert taken.to_pylist() == days[::-1] +@pytest.mark.parametrize("time_unit", ["us", "ns"]) +def test_series_time_take(time_unit) -> None: + from datetime import time + + def time_maker(h, m, s, us): + if us is None: + return None + return time(h, m, s, us) + + times = list(map(time_maker, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 4, 1, None, 2, None])) + s = Series.from_pylist(times) + s = s.cast(DataType.time(time_unit)) + taken = s.take(Series.from_pylist([5, 4, 3, 
2, 1, 0])) + assert taken.datatype() == DataType.time(time_unit) + assert taken.to_pylist() == times[::-1] + + def test_series_binary_take() -> None: data = pa.array([b"1", b"2", b"3", None, b"5", None]) diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index 8f7cf68c99..ec89cd87be 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -24,6 +24,7 @@ "str": ["foo", "bar"], "binary": [b"foo", b"bar"], "date": [datetime.date.today(), datetime.date.today()], + "time": [datetime.time(1, 2, 3, 4), datetime.time(5, 6, 7, 8)], "list": [[1, 2], [3]], "struct": [{"a": 1, "b": 2.0}, {"b": 3.0}], "empty_struct": [{}, {}], @@ -41,6 +42,7 @@ "str": DataType.string(), "binary": DataType.binary(), "date": DataType.date(), + "time": DataType.time(TimeUnit.us()), "list": DataType.list(DataType.int64()), "struct": DataType.struct({"a": DataType.int64(), "b": DataType.float64()}), "empty_struct": DataType.struct({"": DataType.null()}), @@ -65,6 +67,7 @@ "str": pa.large_string(), "binary": pa.large_binary(), "date": pa.date32(), + "time": pa.time64("us"), "list": pa.large_list(pa.int64()), "struct": pa.struct({"a": pa.int64(), "b": pa.float64()}), "empty_struct": pa.struct({}), @@ -91,6 +94,8 @@ "binary": pa.array(PYTHON_TYPE_ARRAYS["binary"], pa.binary()), "boolean": pa.array(PYTHON_TYPE_ARRAYS["bool"], pa.bool_()), "date32": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date32()), + "time64_microseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("us")), + "time64_nanoseconds": pa.array(PYTHON_TYPE_ARRAYS["time"], pa.time64("ns")), "list": pa.array(PYTHON_TYPE_ARRAYS["list"], pa.list_(pa.int64())), "fixed_size_list": pa.array([[1, 2], [3, 4]], pa.list_(pa.int64(), 2)), "struct": pa.array(PYTHON_TYPE_ARRAYS["struct"], pa.struct([("a", pa.int64()), ("b", pa.float64())])), @@ -140,6 +145,8 @@ "binary": pa.large_binary(), "boolean": pa.bool_(), "date32": pa.date32(), + "time64_microseconds": pa.time64("us"), + "time64_nanoseconds": pa.time64("ns"), "list": pa.large_list(pa.int64()), "fixed_size_list": pa.list_(pa.int64(), 2), "struct": pa.struct([("a", pa.int64()), ("b", pa.float64())]), From 0a51db16a56375f0b29b1cb0698e32a72b828a0e Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Fri, 23 Feb 2024 19:41:26 -0800 Subject: [PATCH 08/11] [CHORE] Remove non-`MicroPartition` and non-`ScanOperator` paths (#1946) This PR removes the legacy non-`MicroPartition` and non-`ScanOperator` code paths, refactoring where required. As a drive-by, this PR also refactors the query optimization tests to use structural equality assertions instead of repr-based ones, and adds additional test coverage around scan pushdowns. 
Closes #1863 Closes #1939 Closes #1945 --- .github/workflows/python-package.yml | 10 - daft/daft.pyi | 11 +- daft/execution/execution_step.py | 137 +--- daft/execution/physical_plan.py | 74 +- daft/execution/rust_physical_plan_shim.py | 34 - daft/io/_csv.py | 4 +- daft/io/_iceberg.py | 2 +- daft/io/_json.py | 4 +- daft/io/_parquet.py | 4 +- daft/io/common.py | 74 +- daft/logical/builder.py | 27 +- daft/logical/schema.py | 6 - .../plan_scheduler/physical_plan_scheduler.py | 6 +- daft/runners/pyrunner.py | 22 +- daft/runners/ray_runner.py | 32 +- daft/runners/runner_io.py | 57 +- daft/table/__init__.py | 15 +- src/daft-core/src/python/schema.rs | 5 - src/daft-plan/src/builder.rs | 100 +-- src/daft-plan/src/logical_ops/project.rs | 18 +- src/daft-plan/src/logical_ops/source.rs | 13 +- .../src/optimization/logical_plan_tracker.rs | 42 +- src/daft-plan/src/optimization/mod.rs | 2 + src/daft-plan/src/optimization/optimizer.rs | 129 ++-- .../optimization/rules/drop_repartition.rs | 77 +- .../optimization/rules/push_down_filter.rs | 675 ++++++++++++------ .../src/optimization/rules/push_down_limit.rs | 259 +++---- .../rules/push_down_projection.rs | 223 +++--- src/daft-plan/src/optimization/test/mod.rs | 44 ++ src/daft-plan/src/physical_ops/csv.rs | 47 +- src/daft-plan/src/physical_ops/explode.rs | 14 +- src/daft-plan/src/physical_ops/json.rs | 47 +- src/daft-plan/src/physical_ops/mod.rs | 6 +- src/daft-plan/src/physical_ops/parquet.rs | 47 +- src/daft-plan/src/physical_ops/project.rs | 17 +- src/daft-plan/src/physical_plan.rs | 214 +----- src/daft-plan/src/planner.rs | 72 +- src/daft-plan/src/source_info/mod.rs | 91 +-- src/daft-plan/src/test/mod.rs | 48 +- src/daft-scan/src/glob.rs | 2 +- src/daft-scan/src/lib.rs | 9 + tests/dataframe/test_creation.py | 15 +- .../io/parquet/test_reads_s3_minio.py | 9 +- tests/io/test_merge_scan_tasks.py | 5 - 44 files changed, 1024 insertions(+), 1725 deletions(-) create mode 100644 src/daft-plan/src/optimization/test/mod.rs diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 69e42b4378..073eb1841a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -22,7 +22,6 @@ jobs: daft-runner: [py, ray] pyarrow-version: [6.0.1, 12.0] os: [ubuntu, windows] - micropartitions: [1, 0] exclude: - daft-runner: ray pyarrow-version: 6.0.1 @@ -92,7 +91,6 @@ jobs: # cargo llvm-cov --no-run --lcov --output-path report-output/rust-coverage-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.daft-runner }}-${{ matrix.pyarrow-version }}.lcov env: DAFT_RUNNER: ${{ matrix.daft-runner }} - DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} - name: Build library and Test with pytest (windows) if: ${{ (matrix.os == 'windows') }} run: | @@ -182,7 +180,6 @@ jobs: matrix: python-version: ['3.7'] daft-runner: [py, ray] - micropartitions: [1, 0] steps: - uses: actions/checkout@v4 with: @@ -222,7 +219,6 @@ jobs: pytest tests/integration/test_tpch.py --durations=50 env: DAFT_RUNNER: ${{ matrix.daft-runner }} - DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} - name: Send Slack notification on failure uses: slackapi/slack-github-action@v1.24.0 if: ${{ failure() && (github.ref == 'refs/heads/main') }} @@ -255,7 +251,6 @@ jobs: matrix: python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs daft-runner: [py, ray] - micropartitions: [1, 0] steps: - uses: actions/checkout@v4 with: @@ -298,7 +293,6 @@ jobs: pytest tests/integration/io -m 'integration' --durations=50 env: DAFT_RUNNER: ${{ 
matrix.daft-runner }} - DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} - name: Send Slack notification on failure uses: slackapi/slack-github-action@v1.24.0 if: ${{ failure() && (github.ref == 'refs/heads/main') }} @@ -333,7 +327,6 @@ jobs: matrix: python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs daft-runner: [py, ray] - micropartitions: [1, 0] # These permissions are needed to interact with GitHub's OIDC Token endpoint. # This is used in the step "Assume GitHub Actions AWS Credentials" permissions: @@ -395,7 +388,6 @@ jobs: pytest tests/integration/io -m 'integration' --credentials --durations=50 env: DAFT_RUNNER: ${{ matrix.daft-runner }} - DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} - name: Send Slack notification on failure uses: slackapi/slack-github-action@v1.24.0 if: ${{ failure() }} @@ -430,7 +422,6 @@ jobs: matrix: python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs daft-runner: [py, ray] - micropartitions: [1] steps: - uses: actions/checkout@v4 with: @@ -475,7 +466,6 @@ jobs: pytest tests/integration/iceberg -m 'integration' --durations=50 env: DAFT_RUNNER: ${{ matrix.daft-runner }} - DAFT_MICROPARTITIONS: ${{ matrix.micropartitions }} - name: Send Slack notification on failure uses: slackapi/slack-github-action@v1.24.0 if: ${{ failure() && (github.ref == 'refs/heads/main') }} diff --git a/daft/daft.pyi b/daft/daft.pyi index a87403a3b2..8f9ac04c74 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -847,7 +847,6 @@ class PySchema: def __getitem__(self, name: str) -> PyField: ... def names(self) -> list[str]: ... def union(self, other: PySchema) -> PySchema: ... - def apply_hints(self, other: PySchema) -> PySchema: ... def eq(self, other: PySchema) -> bool: ... @staticmethod def from_field_name_and_types(names_and_types: list[tuple[str, PyDataType]]) -> PySchema: ... @@ -1164,9 +1163,7 @@ class PhysicalPlanScheduler: def num_partitions(self) -> int: ... def partition_spec(self) -> PartitionSpec: ... def repr_ascii(self, simple: bool) -> str: ... - def to_partition_tasks( - self, psets: dict[str, list[PartitionT]], is_ray_runner: bool - ) -> physical_plan.InProgressPhysicalPlan: ... + def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.InProgressPhysicalPlan: ... class LogicalPlanBuilder: """ @@ -1182,11 +1179,7 @@ class LogicalPlanBuilder: partition_key: str, cache_entry: PartitionCacheEntry, schema: PySchema, num_partitions: int, size_bytes: int ) -> LogicalPlanBuilder: ... @staticmethod - def table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ... - @staticmethod - def table_scan( - file_infos: FileInfos, schema: PySchema, file_format_config: FileFormatConfig, storage_config: StorageConfig - ) -> LogicalPlanBuilder: ... + def table_scan(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ... def project(self, projection: list[PyExpr], resource_request: ResourceRequest) -> LogicalPlanBuilder: ... def filter(self, predicate: PyExpr) -> LogicalPlanBuilder: ... def limit(self, limit: int, eager: bool) -> LogicalPlanBuilder: ... 
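The daft.pyi hunk above narrows plan construction to a single table_scan(scan_operator) entry point, and the daft/io/common.py hunk further down in this patch shows the only remaining way a glob read is built. A minimal sketch of that flow, assuming the caller has already constructed the FileFormatConfig and StorageConfig (as read_csv/read_json/read_parquet do); the helper name build_scan_plan is illustrative only and is not part of the patch:

    from daft.daft import ScanOperatorHandle
    from daft.logical.builder import LogicalPlanBuilder

    def build_scan_plan(paths, file_format_config, storage_config, schema_hint=None):
        # Every glob read now goes through a ScanOperatorHandle ...
        scan_op = ScanOperatorHandle.glob_scan(
            paths,                    # list[str] of glob paths
            file_format_config,       # CSV/JSON/Parquet source config built by the read_* entry point
            storage_config,           # native or Python storage config
            schema_hint=schema_hint,  # optional PySchema derived from user-provided hints
        )
        # ... and a single LogicalPlanBuilder constructor (table_scan in daft.pyi above).
        return LogicalPlanBuilder.from_tabular_scan(scan_operator=scan_op)

This mirrors the new get_tabular_files_scan in daft/io/common.py below; the legacy FileInfos/ReadFile path removed from execution_step.py and physical_plan.py no longer participates in reads.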
diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 73857cc4bb..5bd5e483df 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -11,18 +11,7 @@ else: from typing import Protocol -from daft.daft import ( - CsvSourceConfig, - FileFormat, - FileFormatConfig, - IOConfig, - JoinType, - JsonReadOptions, - JsonSourceConfig, - ParquetSourceConfig, - ResourceRequest, - StorageConfig, -) +from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest from daft.expressions import Expression, ExpressionsProjection, col from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema @@ -32,8 +21,6 @@ PartialPartitionMetadata, PartitionMetadata, PartitionT, - TableParseCSVOptions, - TableReadOptions, ) from daft.table import MicroPartition, table_io @@ -303,128 +290,6 @@ def num_outputs(self) -> int: return 1 -@dataclass(frozen=True) -class ReadFile(SingleOutputInstruction): - index: int | None - # Known number of rows. - file_rows: int | None - # Max number of rows to read. - limit_rows: int | None - schema: Schema - storage_config: StorageConfig - columns_to_read: list[str] | None - file_format_config: FileFormatConfig - - def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - return self._read_file(inputs) - - def _read_file(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - assert len(inputs) == 1 - [filepaths_partition] = inputs - partition = self._handle_tabular_files_scan( - filepaths_partition=filepaths_partition, - ) - return [partition] - - def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: - assert len(input_metadatas) == 1 - - num_rows = self.file_rows - # Only take the file read limit into account if we know how big the file is to begin with. 
- if num_rows is not None and self.limit_rows is not None: - num_rows = min(num_rows, self.limit_rows) - - return [ - PartialPartitionMetadata( - num_rows=num_rows, - size_bytes=None, - ) - ] - - def _handle_tabular_files_scan( - self, - filepaths_partition: MicroPartition, - ) -> MicroPartition: - data = filepaths_partition.to_pydict() - filepaths = data["path"] - - if self.index is not None: - filepaths = [filepaths[self.index]] - - # Common options for reading vPartition - read_options = TableReadOptions( - num_rows=self.limit_rows, - column_names=self.columns_to_read, # read only specified columns - ) - - file_format = self.file_format_config.file_format() - format_config = self.file_format_config.config - if file_format == FileFormat.Csv: - assert isinstance(format_config, CsvSourceConfig) - table = MicroPartition.concat( - [ - table_io.read_csv( - file=fp, - schema=self.schema, - storage_config=self.storage_config, - csv_options=TableParseCSVOptions( - delimiter=format_config.delimiter, - header_index=0 if format_config.has_headers else None, - double_quote=format_config.double_quote, - quote=format_config.quote, - escape_char=format_config.escape_char, - comment=format_config.comment, - buffer_size=format_config.buffer_size, - chunk_size=format_config.chunk_size, - ), - read_options=read_options, - ) - for fp in filepaths - ] - ) - elif file_format == FileFormat.Json: - assert isinstance(format_config, JsonSourceConfig) - table = MicroPartition.concat( - [ - table_io.read_json( - file=fp, - schema=self.schema, - storage_config=self.storage_config, - json_read_options=JsonReadOptions( - buffer_size=format_config.buffer_size, chunk_size=format_config.chunk_size - ), - read_options=read_options, - ) - for fp in filepaths - ] - ) - elif file_format == FileFormat.Parquet: - assert isinstance(format_config, ParquetSourceConfig) - table = MicroPartition.concat( - [ - table_io.read_parquet( - file=fp, - schema=self.schema, - storage_config=self.storage_config, - read_options=read_options, - ) - for fp in filepaths - ] - ) - else: - raise NotImplementedError(f"PyRunner has not implemented scan: {file_format}") - - expected_schema = ( - Schema._from_fields([self.schema[name] for name in read_options.column_names]) - if read_options.column_names is not None - else self.schema - ) - assert ( - table.schema() == expected_schema - ), f"Expected table to have schema:\n{expected_schema}\n\nReceived instead:\n{table.schema()}" - return table - - @dataclass(frozen=True) class WriteFile(SingleOutputInstruction): file_format: FileFormat diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index f7509493b3..6989915e77 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -22,14 +22,7 @@ from typing import Generator, Generic, Iterable, Iterator, TypeVar, Union from daft.context import get_context -from daft.daft import ( - FileFormat, - FileFormatConfig, - IOConfig, - JoinType, - ResourceRequest, - StorageConfig, -) +from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest from daft.execution import execution_step from daft.execution.execution_step import ( Instruction, @@ -84,71 +77,6 @@ def partition_read( ) -def file_read( - child_plan: InProgressPhysicalPlan[PartitionT], - # Max number of rows to read. 
- limit_rows: int | None, - schema: Schema, - storage_config: StorageConfig, - columns_to_read: list[str] | None, - file_format_config: FileFormatConfig, -) -> InProgressPhysicalPlan[PartitionT]: - """child_plan represents partitions with filenames. - - Yield a plan to read those filenames. - """ - materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() - stage_id = next(stage_id_counter) - output_partition_index = 0 - - while True: - # Check if any inputs finished executing. - while len(materializations) > 0 and materializations[0].done(): - done_task = materializations.popleft() - - vpartition = done_task.vpartition() - file_infos = vpartition.to_pydict() - file_sizes_bytes = file_infos["size"] - file_rows = file_infos["num_rows"] - - # Emit one partition for each file (NOTE: hardcoded for now). - for i in range(len(vpartition)): - file_read_step = PartitionTaskBuilder[PartitionT]( - inputs=[done_task.partition()], - partial_metadatas=None, # Child's metadata doesn't really matter for a file read - ).add_instruction( - instruction=execution_step.ReadFile( - index=i, - file_rows=file_rows[i], - limit_rows=limit_rows, - schema=schema, - storage_config=storage_config, - columns_to_read=columns_to_read, - file_format_config=file_format_config, - ), - # Set the filesize as the memory request. - # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.) - resource_request=ResourceRequest(memory_bytes=file_sizes_bytes[i]), - ) - yield file_read_step - output_partition_index += 1 - - # Materialize a single dependency. - try: - child_step = next(child_plan) - if isinstance(child_step, PartitionTaskBuilder): - child_step = child_step.finalize_partition_task_single_output(stage_id=stage_id) - materializations.append(child_step) - yield child_step - - except StopIteration: - if len(materializations) > 0: - logger.debug("file_read blocked on completion of first source in: %s", materializations) - yield None - else: - return - - def file_write( child_plan: InProgressPhysicalPlan[PartitionT], file_format: FileFormat, diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 25db7bec7a..9b3c94b9f8 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -1,19 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterator, cast from daft.daft import ( FileFormat, - FileFormatConfig, IOConfig, JoinType, PyExpr, PySchema, - PyTable, ResourceRequest, ScanTask, - StorageConfig, ) from daft.execution import execution_step, physical_plan from daft.expressions import Expression, ExpressionsProjection @@ -94,36 +90,6 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ] -def tabular_scan( - schema: PySchema, - columns_to_read: list[str] | None, - file_info_table: PyTable, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - limit: int, - is_ray_runner: bool, -) -> physical_plan.InProgressPhysicalPlan[PartitionT]: - # TODO(Clark): Fix this Ray runner hack. 
- part = MicroPartition._from_pytable(file_info_table) - if is_ray_runner: - import ray - - parts = [ray.put(part)] - else: - parts = [part] - parts_t = cast(Iterator[PartitionT], parts) - - file_info_iter = physical_plan.partition_read(iter(parts_t)) - return physical_plan.file_read( - child_plan=file_info_iter, - limit_rows=limit, - schema=Schema._from_pyschema(schema), - storage_config=storage_config, - columns_to_read=columns_to_read, - file_format_config=file_format_config, - ) - - def project( input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], resource_request: ResourceRequest ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: diff --git a/daft/io/_csv.py b/daft/io/_csv.py index ec3b1f1e7f..aa30959dad 100644 --- a/daft/io/_csv.py +++ b/daft/io/_csv.py @@ -14,7 +14,7 @@ ) from daft.dataframe import DataFrame from daft.datatype import DataType -from daft.io.common import _get_tabular_files_scan +from daft.io.common import get_tabular_files_scan @PublicAPI @@ -84,5 +84,5 @@ def read_csv( storage_config = StorageConfig.native(NativeStorageConfig(True, io_config)) else: storage_config = StorageConfig.python(PythonStorageConfig(io_config=io_config)) - builder = _get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) + builder = get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) return DataFrame(builder) diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py index c936779d3d..fa14dca710 100644 --- a/daft/io/_iceberg.py +++ b/daft/io/_iceberg.py @@ -109,5 +109,5 @@ def read_iceberg( iceberg_operator = IcebergScanOperator(pyiceberg_table, storage_config=storage_config) handle = ScanOperatorHandle.from_python_scan_operator(iceberg_operator) - builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator(scan_operator=handle) + builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) return DataFrame(builder) diff --git a/daft/io/_json.py b/daft/io/_json.py index c052fb2c52..95a00fb449 100644 --- a/daft/io/_json.py +++ b/daft/io/_json.py @@ -14,7 +14,7 @@ ) from daft.dataframe import DataFrame from daft.datatype import DataType -from daft.io.common import _get_tabular_files_scan +from daft.io.common import get_tabular_files_scan @PublicAPI @@ -56,5 +56,5 @@ def read_json( storage_config = StorageConfig.native(NativeStorageConfig(True, io_config)) else: storage_config = StorageConfig.python(PythonStorageConfig(io_config=io_config)) - builder = _get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) + builder = get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) return DataFrame(builder) diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index a41d82ec64..9f77f28bb0 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -14,7 +14,7 @@ ) from daft.dataframe import DataFrame from daft.datatype import DataType -from daft.io.common import _get_tabular_files_scan +from daft.io.common import get_tabular_files_scan @PublicAPI @@ -61,5 +61,5 @@ def read_parquet( else: storage_config = StorageConfig.python(PythonStorageConfig(io_config=io_config)) - builder = _get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) + builder = get_tabular_files_scan(path, schema_hints, file_format_config, storage_config=storage_config) return DataFrame(builder) diff --git a/daft/io/common.py b/daft/io/common.py index a55ab92992..3668963c3d 100644 --- 
a/daft/io/common.py +++ b/daft/io/common.py @@ -1,16 +1,8 @@ from __future__ import annotations -import os from typing import TYPE_CHECKING -from daft.context import get_context -from daft.daft import ( - FileFormatConfig, - NativeStorageConfig, - PythonStorageConfig, - ScanOperatorHandle, - StorageConfig, -) +from daft.daft import FileFormatConfig, ScanOperatorHandle, StorageConfig from daft.datatype import DataType from daft.logical.builder import LogicalPlanBuilder from daft.logical.schema import Schema @@ -26,7 +18,7 @@ def _get_schema_from_hints(hints: dict[str, DataType]) -> Schema: raise NotImplementedError(f"Unsupported schema hints: {type(hints)}") -def _get_tabular_files_scan( +def get_tabular_files_scan( path: str | list[str], schema_hints: dict[str, DataType] | None, file_format_config: FileFormatConfig, @@ -34,59 +26,23 @@ def _get_tabular_files_scan( ) -> LogicalPlanBuilder: """Returns a TabularFilesScan LogicalPlan for a given glob filepath.""" # Glob the path using the Runner - # NOTE: Globbing will always need the IOConfig, regardless of whether "native reads" are used - io_config = None - if isinstance(storage_config.config, NativeStorageConfig): - io_config = storage_config.config.io_config - elif isinstance(storage_config.config, PythonStorageConfig): - io_config = storage_config.config.io_config - else: - raise NotImplementedError(f"Tabular scan with config not implemented: {storage_config.config}") - schema_hint = _get_schema_from_hints(schema_hints) if schema_hints is not None else None - ### FEATURE_FLAG: $DAFT_MICROPARTITIONS - # - # This environment variable will make Daft use the new "v2 scans" and MicroPartitions when building Daft logical plans - if os.getenv("DAFT_MICROPARTITIONS", "1") == "1": - scan_op: ScanOperatorHandle - if isinstance(path, list): - scan_op = ScanOperatorHandle.glob_scan( - path, - file_format_config, - storage_config, - schema_hint=schema_hint._schema if schema_hint is not None else None, - ) - elif isinstance(path, str): - scan_op = ScanOperatorHandle.glob_scan( - [path], - file_format_config, - storage_config, - schema_hint=schema_hint._schema if schema_hint is not None else None, - ) - else: - raise NotImplementedError(f"_get_tabular_files_scan cannot construct ScanOperatorHandle for input: {path}") - - builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator( - scan_operator=scan_op, - ) - return builder - - paths = path if isinstance(path, list) else [str(path)] - runner_io = get_context().runner().runner_io() - file_infos = runner_io.glob_paths_details(paths, file_format_config=file_format_config, io_config=io_config) + if isinstance(path, list): + paths = path + elif isinstance(path, str): + paths = [path] + else: + raise NotImplementedError(f"get_tabular_files_scan cannot construct ScanOperatorHandle for input: {path}") - # Infer schema - schema = runner_io.get_schema_from_first_filepath(file_infos, file_format_config, storage_config) + scan_op = ScanOperatorHandle.glob_scan( + paths, + file_format_config, + storage_config, + schema_hint=schema_hint._schema if schema_hint is not None else None, + ) - # Apply hints from schema_hints if provided - if schema_hint is not None: - schema = schema.apply_hints(schema_hint) - # Construct plan builder = LogicalPlanBuilder.from_tabular_scan( - file_infos=file_infos, - schema=schema, - file_format_config=file_format_config, - storage_config=storage_config, + scan_operator=scan_op, ) return builder diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 
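The rewritten `get_tabular_files_scan` above is now just: normalize `path` into a list of globs, build a `ScanOperatorHandle.glob_scan` over them, and wrap it in `LogicalPlanBuilder.from_tabular_scan`; the runner no longer globs paths or samples a schema itself. A minimal sketch of that control flow, with hypothetical stand-in classes and with the format config, storage config, and schema hints omitted for brevity:

    from typing import List, Union

    class ScanOp:
        """Stand-in for ScanOperatorHandle.glob_scan (the real call also takes
        file_format_config, storage_config, and an optional schema_hint)."""
        def __init__(self, globs: List[str]) -> None:
            self.globs = globs

    class Builder:
        """Stand-in for LogicalPlanBuilder.from_tabular_scan."""
        def __init__(self, scan_operator: ScanOp) -> None:
            self.scan_operator = scan_operator

    def tabular_files_scan(path: Union[str, List[str]]) -> Builder:
        if isinstance(path, list):
            paths = path
        elif isinstance(path, str):
            paths = [path]
        else:
            raise NotImplementedError(f"cannot construct a scan for input: {path!r}")
        return Builder(ScanOp(paths))

    # Both forms produce a single glob-scan-backed plan builder:
    # tabular_files_scan("data/*.parquet")
    # tabular_files_scan(["a.csv", "b.csv"])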
f5ad0451aa..996c52ec64 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -3,22 +3,13 @@ import pathlib from typing import TYPE_CHECKING -from daft.daft import ( - CountMode, - FileFormat, - FileFormatConfig, - FileInfos, - IOConfig, - JoinStrategy, - JoinType, -) +from daft.daft import CountMode, FileFormat, IOConfig, JoinStrategy, JoinType from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder from daft.daft import ( PartitionScheme, PyDaftExecutionConfig, ResourceRequest, ScanOperatorHandle, - StorageConfig, ) from daft.expressions import Expression, col from daft.logical.schema import Schema @@ -82,25 +73,13 @@ def from_in_memory_scan( ) return cls(builder) - @classmethod - def from_tabular_scan_with_scan_operator( - cls, - *, - scan_operator: ScanOperatorHandle, - ) -> LogicalPlanBuilder: - builder = _LogicalPlanBuilder.table_scan_with_scan_operator(scan_operator) - return cls(builder) - @classmethod def from_tabular_scan( cls, *, - file_infos: FileInfos, - schema: Schema, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, + scan_operator: ScanOperatorHandle, ) -> LogicalPlanBuilder: - builder = _LogicalPlanBuilder.table_scan(file_infos, schema._schema, file_format_config, storage_config) + builder = _LogicalPlanBuilder.table_scan(scan_operator) return cls(builder) def project( diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 9a1950db14..2655885bb9 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -145,12 +145,6 @@ def union(self, other: Schema) -> Schema: return Schema._from_pyschema(self._schema.union(other._schema)) - def apply_hints(self, other: Schema) -> Schema: - if not isinstance(other, Schema): - raise ValueError(f"Expected Schema, got other: {type(other)}") - - return Schema._from_pyschema(self._schema.apply_hints(other._schema)) - def __reduce__(self) -> tuple: return Schema._from_pyschema, (self._schema,) diff --git a/daft/plan_scheduler/physical_plan_scheduler.py b/daft/plan_scheduler/physical_plan_scheduler.py index 110287de4d..459da8ed3a 100644 --- a/daft/plan_scheduler/physical_plan_scheduler.py +++ b/daft/plan_scheduler/physical_plan_scheduler.py @@ -32,7 +32,5 @@ def pretty_print(self, simple: bool = False) -> str: def __repr__(self) -> str: return self._scheduler.repr_ascii(simple=False) - def to_partition_tasks( - self, psets: dict[str, list[PartitionT]], is_ray_runner: bool - ) -> physical_plan.MaterializedPhysicalPlan: - return physical_plan.materialize(self._scheduler.to_partition_tasks(psets, is_ray_runner)) + def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.MaterializedPhysicalPlan: + return physical_plan.materialize(self._scheduler.to_partition_tasks(psets)) diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 41b435ff24..8bbd6fd03d 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -9,19 +9,12 @@ import psutil from daft.context import get_context -from daft.daft import ( - FileFormatConfig, - FileInfos, - IOConfig, - ResourceRequest, - StorageConfig, -) +from daft.daft import FileFormatConfig, FileInfos, IOConfig, ResourceRequest from daft.execution import physical_plan from daft.execution.execution_step import Instruction, PartitionTask from daft.filesystem import glob_path_with_stats from daft.internal.gpu import cuda_device_count from daft.logical.builder import LogicalPlanBuilder -from daft.logical.schema import Schema from daft.runners import runner_io from daft.runners.partitioning import ( 
MaterializedResult, @@ -113,17 +106,6 @@ def glob_paths_details( return file_infos - def get_schema_from_first_filepath( - self, - file_infos: FileInfos, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - ) -> Schema: - if len(file_infos) == 0: - raise ValueError("No files to get schema from") - # Naively retrieve the first filepath in the PartitionSet - return runner_io.sample_schema(file_infos[0].file_path, file_format_config, storage_config) - class PyRunner(Runner[MicroPartition]): def __init__(self, use_thread_pool: bool | None) -> None: @@ -163,7 +145,7 @@ def run_iter( plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config) psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()} # Get executable tasks from planner. - tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False) + tasks = plan_scheduler.to_partition_tasks(psets) with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"): results_gen = self._physical_plan_to_partitions(tasks) yield from results_gen diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index e2da71b62b..33cd2cf568 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -32,7 +32,6 @@ IOConfig, PyDaftExecutionConfig, ResourceRequest, - StorageConfig, ) from daft.datatype import DataType from daft.execution.execution_step import ( @@ -126,17 +125,6 @@ def _to_pandas_ref(df: pd.DataFrame | ray.ObjectRef[pd.DataFrame]) -> ray.Object raise ValueError("Expected a Ray object ref or a Pandas DataFrame, " f"got {type(df)}") -@ray.remote -def sample_schema_from_filepath( - first_file_path: str, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, -) -> Schema: - """Ray remote function to run schema sampling on top of a MicroPartition containing a single filepath""" - # Currently just samples the Schema from the first file - return runner_io.sample_schema(first_file_path, file_format_config, storage_config) - - @dataclass class RayPartitionSet(PartitionSet[ray.ObjectRef]): _results: dict[PartID, RayMaterializedResult] @@ -244,24 +232,6 @@ def glob_paths_details( ._table ) - def get_schema_from_first_filepath( - self, - file_infos: FileInfos, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - ) -> Schema: - if len(file_infos) == 0: - raise ValueError("No files to get schema from") - # Naively retrieve the first filepath in the file info table. - first_path = file_infos[0].file_path - return ray.get( - sample_schema_from_filepath.remote( - first_path, - file_format_config, - storage_config, - ) - ) - def partition_set_from_ray_dataset( self, ds: RayDataset, @@ -507,7 +477,7 @@ def _run_plan( result_uuid: str, ) -> None: # Get executable tasks from plan scheduler. 
- tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) + tasks = plan_scheduler.to_partition_tasks(psets) daft_execution_config = self.execution_configs_objref_by_df[result_uuid] inflight_tasks: dict[str, PartitionTask[ray.ObjectRef]] = dict() diff --git a/daft/runners/runner_io.py b/daft/runners/runner_io.py index c03c7fcac8..980a855875 100644 --- a/daft/runners/runner_io.py +++ b/daft/runners/runner_io.py @@ -3,19 +3,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING -from daft.daft import ( - CsvSourceConfig, - FileFormat, - FileFormatConfig, - FileInfos, - IOConfig, - JsonSourceConfig, - ParquetSourceConfig, - StorageConfig, -) -from daft.logical.schema import Schema -from daft.runners.partitioning import TableParseCSVOptions -from daft.table import schema_inference +from daft.daft import FileFormatConfig, FileInfos, IOConfig if TYPE_CHECKING: pass @@ -43,46 +31,3 @@ def glob_paths_details( FileInfo: The file infos for the globbed paths. """ raise NotImplementedError() - - @abstractmethod - def get_schema_from_first_filepath( - self, - file_infos: FileInfos, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, - ) -> Schema: - raise NotImplementedError() - - -def sample_schema( - filepath: str, - file_format_config: FileFormatConfig, - storage_config: StorageConfig, -) -> Schema: - """Helper method that samples a schema from the specified source""" - file_format = file_format_config.file_format() - config = file_format_config.config - if file_format == FileFormat.Csv: - assert isinstance(config, CsvSourceConfig) - return schema_inference.from_csv( - file=filepath, - storage_config=storage_config, - csv_options=TableParseCSVOptions( - delimiter=config.delimiter, - header_index=0 if config.has_headers else None, - ), - ) - elif file_format == FileFormat.Json: - assert isinstance(config, JsonSourceConfig) - return schema_inference.from_json( - file=filepath, - storage_config=storage_config, - ) - elif file_format == FileFormat.Parquet: - assert isinstance(config, ParquetSourceConfig) - return schema_inference.from_parquet( - file=filepath, - storage_config=storage_config, - ) - else: - raise NotImplementedError(f"Schema inference for {file_format} not implemented") diff --git a/daft/table/__init__.py b/daft/table/__init__.py index 5b66e7b616..5ecab85779 100644 --- a/daft/table/__init__.py +++ b/daft/table/__init__.py @@ -1,18 +1,7 @@ from __future__ import annotations -import os - -from .table import Table, read_parquet_into_pyarrow, read_parquet_into_pyarrow_bulk - # Need to import after `.table` due to circular dep issue otherwise -from .micropartition import MicroPartition as _MicroPartition # isort:skip - - -MicroPartition = _MicroPartition - -# Use $DAFT_MICROPARTITIONS envvar as a feature flag to turn off MicroPartitions -if os.getenv("DAFT_MICROPARTITIONS", "1") != "1": - MicroPartition = Table # type: ignore - +from .micropartition import MicroPartition +from .table import Table, read_parquet_into_pyarrow, read_parquet_into_pyarrow_bulk __all__ = ["MicroPartition", "Table", "read_parquet_into_pyarrow", "read_parquet_into_pyarrow_bulk"] diff --git a/src/daft-core/src/python/schema.rs b/src/daft-core/src/python/schema.rs index f4f50bdadf..975aeffc5c 100644 --- a/src/daft-core/src/python/schema.rs +++ b/src/daft-core/src/python/schema.rs @@ -33,11 +33,6 @@ impl PySchema { Ok(new_schema.into()) } - pub fn apply_hints(&self, hints: &PySchema) -> PyResult { - let new_schema = Arc::new(self.schema.apply_hints(&hints.schema)?); - 
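The runner-side schema sampling deleted here (`get_schema_from_first_filepath` plus the `sample_schema` helper) inferred a schema from only the first globbed file and trusted it for the whole scan; with scan operators, schema handling lives behind `ScanOperatorHandle` instead. A toy sketch of the removed first-file approach, with the sampler passed in as a stand-in callable rather than Daft's `schema_inference` functions:

    from typing import Callable, Dict, List

    Schema = Dict[str, str]  # toy schema: column name -> type name

    def schema_from_first_filepath(
        filepaths: List[str],
        sample_schema: Callable[[str], Schema],
    ) -> Schema:
        if len(filepaths) == 0:
            raise ValueError("No files to get schema from")
        # Naive: only the first file is sampled; later files may disagree.
        return sample_schema(filepaths[0])

    # schema_from_first_filepath(["a.csv", "b.csv"], lambda fp: {"a": "int64", "b": "utf8"})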
Ok(new_schema.into()) - } - pub fn eq(&self, other: &PySchema) -> PyResult { Ok(self.schema.fields.eq(&other.schema.fields)) } diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 6c4b434c11..1fa4819413 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -7,10 +7,7 @@ use crate::{ partitioning::PartitionSchemeConfig, planner::plan, sink_info::{OutputFileInfo, SinkInfo}, - source_info::{ - ExternalInfo as ExternalSourceInfo, FileInfos as InputFileInfos, LegacyExternalInfo, - SourceInfo, - }, + source_info::SourceInfo, JoinStrategy, JoinType, PartitionScheme, PhysicalPlanScheduler, ResourceRequest, }; use common_error::{DaftError, DaftResult}; @@ -18,11 +15,7 @@ use common_io_config::IOConfig; use daft_core::schema::Schema; use daft_core::schema::SchemaRef; use daft_dsl::Expr; -use daft_scan::{ - file_format::{FileFormat, FileFormatConfig}, - storage_config::{PyStorageConfig, StorageConfig}, - Pushdowns, ScanExternalInfo, ScanOperatorRef, -}; +use daft_scan::{file_format::FileFormat, Pushdowns, ScanExternalInfo, ScanOperatorRef}; #[cfg(feature = "python")] use { @@ -30,7 +23,7 @@ use { common_daft_config::PyDaftExecutionConfig, daft_core::python::schema::PySchema, daft_dsl::python::PyExpr, - daft_scan::{file_format::PyFileFormatConfig, python::pylib::ScanOperatorHandle}, + daft_scan::python::pylib::ScanOperatorHandle, pyo3::prelude::*, }; @@ -39,7 +32,7 @@ use { /// /// This builder holds the current root (sink) of the logical plan, and the building methods return /// a brand new builder holding a new plan; i.e., this is an immutable builder. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LogicalPlanBuilder { // The current root of the logical plan in this builder. pub plan: Arc, @@ -72,56 +65,33 @@ impl LogicalPlanBuilder { Ok(logical_plan.into()) } - pub fn table_scan_with_scan_operator( + pub fn table_scan( scan_operator: ScanOperatorRef, pushdowns: Option, ) -> DaftResult { let schema = scan_operator.0.schema(); let partitioning_keys = scan_operator.0.partitioning_keys(); - let source_info = - SourceInfo::ExternalInfo(ExternalSourceInfo::Scan(ScanExternalInfo::new( - scan_operator.clone(), - schema.clone(), - partitioning_keys.into(), - pushdowns.unwrap_or_default(), - ))); - let logical_plan: LogicalPlan = - logical_ops::Source::new(schema.clone(), source_info.into()).into(); - Ok(logical_plan.into()) - } - - pub fn table_scan( - file_infos: InputFileInfos, - schema: Arc, - file_format_config: Arc, - storage_config: Arc, - ) -> DaftResult { - Self::table_scan_with_pushdowns( - file_infos, - schema, - file_format_config, - storage_config, - Default::default(), - ) - } - - pub fn table_scan_with_pushdowns( - file_infos: InputFileInfos, - schema: Arc, - file_format_config: Arc, - storage_config: Arc, - pushdowns: Pushdowns, - ) -> DaftResult { - let source_info = - SourceInfo::ExternalInfo(ExternalSourceInfo::Legacy(LegacyExternalInfo::new( - schema.clone(), - file_infos.into(), - file_format_config, - storage_config, - pushdowns, - ))); + let source_info = SourceInfo::ExternalInfo(ScanExternalInfo::new( + scan_operator.clone(), + schema.clone(), + partitioning_keys.into(), + pushdowns.clone().unwrap_or_default(), + )); + // If column selection (projection) pushdown is specified, prune unselected columns from the schema. + let output_schema = if let Some(Pushdowns { columns: Some(columns), .. 
}) = &pushdowns && columns.len() < schema.fields.len() { + let pruned_upstream_schema = schema + .fields + .iter() + .filter_map(|(name, field)| { + columns.contains(name).then(|| field.clone()) + }) + .collect::>(); + Arc::new(Schema::new(pruned_upstream_schema)?) + } else { + schema.clone() + }; let logical_plan: LogicalPlan = - logical_ops::Source::new(schema.clone(), source_info.into()).into(); + logical_ops::Source::new(output_schema, source_info.into()).into(); Ok(logical_plan.into()) } @@ -323,24 +293,8 @@ impl PyLogicalPlanBuilder { } #[staticmethod] - pub fn table_scan_with_scan_operator(scan_operator: ScanOperatorHandle) -> PyResult { - Ok(LogicalPlanBuilder::table_scan_with_scan_operator(scan_operator.into(), None)?.into()) - } - - #[staticmethod] - pub fn table_scan( - file_infos: InputFileInfos, - schema: PySchema, - file_format_config: PyFileFormatConfig, - storage_config: PyStorageConfig, - ) -> PyResult { - Ok(LogicalPlanBuilder::table_scan( - file_infos, - schema.into(), - file_format_config.into(), - storage_config.into(), - )? - .into()) + pub fn table_scan(scan_operator: ScanOperatorHandle) -> PyResult { + Ok(LogicalPlanBuilder::table_scan(scan_operator.into(), None)?.into()) } pub fn project( diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index db87148d11..50ffb2fe78 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -407,7 +407,11 @@ mod tests { use daft_core::{datatypes::Field, DataType}; use daft_dsl::{binary_op, col, lit, Operator}; - use crate::{logical_ops::Project, test::dummy_scan_node, LogicalPlan}; + use crate::{ + logical_ops::Project, + test::{dummy_scan_node, dummy_scan_operator}, + LogicalPlan, + }; /// Test that nested common subexpressions are correctly split /// into multiple levels of projections. @@ -419,10 +423,10 @@ mod tests { /// 3: a+a as aa #[test] fn test_nested_subexpression() -> DaftResult<()> { - let source = dummy_scan_node(vec![ + let source = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) + ])) .build(); let a2 = binary_op(Operator::Plus, &col("a"), &col("a")); let a4 = binary_op(Operator::Plus, &a2, &a2); @@ -463,10 +467,10 @@ mod tests { /// 2. 
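The most substantive addition in the builder.rs hunk above is the projection-pushdown pruning: when the pushdowns carry a column list that is strictly smaller than the scan operator's schema, the Source node's output schema keeps only those fields, preserving the scan schema's field order. A small Python rendering of that rule, with an ordered dict standing in for the Rust Schema (not Daft's actual types):

    from typing import Dict, List, Optional

    def output_schema(
        scan_schema: Dict[str, str],
        pushdown_columns: Optional[List[str]],
    ) -> Dict[str, str]:
        if pushdown_columns is not None and len(pushdown_columns) < len(scan_schema):
            # Keep only pushed-down columns, in the scan schema's original order.
            return {name: dtype for name, dtype in scan_schema.items() if name in pushdown_columns}
        return dict(scan_schema)

    # Field order follows the scan schema, not the pushdown list:
    # output_schema({"a": "Int64", "b": "Utf8", "c": "Float64"}, ["c", "a"])
    # -> {"a": "Int64", "c": "Float64"}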
a+a as aa, a #[test] fn test_shared_subexpression() -> DaftResult<()> { - let source = dummy_scan_node(vec![ + let source = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) + ])) .build(); let a2 = binary_op(Operator::Plus, &col("a"), &col("a")); let expressions = vec![ @@ -500,10 +504,10 @@ mod tests { /// (unchanged) #[test] fn test_vacuous_subexpression() -> DaftResult<()> { - let source = dummy_scan_node(vec![ + let source = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) + ])) .build(); let expressions = vec![ lit(3).alias("x"), diff --git a/src/daft-plan/src/logical_ops/source.rs b/src/daft-plan/src/logical_ops/source.rs index c7fbfb9cad..f6cb37b4c1 100644 --- a/src/daft-plan/src/logical_ops/source.rs +++ b/src/daft-plan/src/logical_ops/source.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use daft_core::schema::SchemaRef; use daft_scan::ScanExternalInfo; -use crate::source_info::{ExternalInfo, SourceInfo}; +use crate::source_info::SourceInfo; #[cfg(feature = "python")] use crate::source_info::InMemoryInfo; @@ -30,19 +30,12 @@ impl Source { let mut res = vec![]; match self.source_info.as_ref() { - SourceInfo::ExternalInfo(ExternalInfo::Legacy(legacy_external_info)) => { - res.push(format!( - "Source: {}", - legacy_external_info.file_format_config.var_name() - )); - res.extend(legacy_external_info.multiline_display()); - } - SourceInfo::ExternalInfo(ExternalInfo::Scan(ScanExternalInfo { + SourceInfo::ExternalInfo(ScanExternalInfo { source_schema, scan_op, partitioning_keys, pushdowns, - })) => { + }) => { use itertools::Itertools; res.extend(scan_op.0.multiline_display()); diff --git a/src/daft-plan/src/optimization/logical_plan_tracker.rs b/src/daft-plan/src/optimization/logical_plan_tracker.rs index 238e08fbf1..75efdb2f40 100644 --- a/src/daft-plan/src/optimization/logical_plan_tracker.rs +++ b/src/daft-plan/src/optimization/logical_plan_tracker.rs @@ -71,22 +71,26 @@ mod tests { use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, + sync::Arc, }; use common_error::DaftResult; use daft_core::{datatypes::Field, DataType}; use daft_dsl::{col, lit}; - use crate::{optimization::logical_plan_tracker::LogicalPlanDigest, test::dummy_scan_node}; + use crate::{ + optimization::logical_plan_tracker::LogicalPlanDigest, + test::{dummy_scan_node, dummy_scan_operator}, + }; #[test] fn node_count() -> DaftResult<()> { // plan is Filter -> Concat -> {Projection -> Source, Projection -> Source}, // and should have a node count of 6. 
- let builder1 = dummy_scan_node(vec![ + let builder1 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]); + ])); assert_eq!( LogicalPlanDigest::new(builder1.plan.as_ref(), &mut Default::default()).node_count, 1usize.try_into().unwrap() @@ -96,10 +100,10 @@ mod tests { LogicalPlanDigest::new(builder1.plan.as_ref(), &mut Default::default()).node_count, 2usize.try_into().unwrap() ); - let builder2 = dummy_scan_node(vec![ + let builder2 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]); + ])); assert_eq!( LogicalPlanDigest::new(builder2.plan.as_ref(), &mut Default::default()).node_count, 1usize.try_into().unwrap() @@ -125,20 +129,14 @@ mod tests { #[test] fn same_plans_eq() -> DaftResult<()> { // Both plan1 and plan2 are Filter -> Project -> Source - let plan1 = dummy_scan_node(vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ]) - .project(vec![col("a")], Default::default())? - .filter(col("a").lt(&lit(2)))? - .build(); - let plan2 = dummy_scan_node(vec![ + let plan1 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .project(vec![col("a")], Default::default())? .filter(col("a").lt(&lit(2)))? .build(); + let plan2 = Arc::new(plan1.as_ref().clone()); // Double-check that logical plans are equal. assert_eq!(plan1, plan2); @@ -157,17 +155,17 @@ mod tests { #[test] fn different_plans_not_eq_op_ordering() -> DaftResult<()> { // plan1 is Project -> Filter -> Source, while plan2 is Filter -> Project -> Source. - let plan1 = dummy_scan_node(vec![ + let plan1 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .filter(col("a").lt(&lit(2)))? .project(vec![col("a")], Default::default())? .build(); - let plan2 = dummy_scan_node(vec![ + let plan2 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .project(vec![col("a")], Default::default())? .filter(col("a").lt(&lit(2)))? .build(); @@ -189,17 +187,17 @@ mod tests { #[test] fn different_plans_not_eq_same_order_diff_config() -> DaftResult<()> { // Both plan1 and plan2 are Filter -> Project -> Source, but with different filter predicates. - let plan1 = dummy_scan_node(vec![ + let plan1 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .project(vec![col("a")], Default::default())? .filter(col("a").lt(&lit(2)))? .build(); - let plan2 = dummy_scan_node(vec![ + let plan2 = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .project(vec![col("a")], Default::default())? .filter(col("a").lt(&lit(4)))? 
.build(); diff --git a/src/daft-plan/src/optimization/mod.rs b/src/daft-plan/src/optimization/mod.rs index f7158979ea..82f349e55f 100644 --- a/src/daft-plan/src/optimization/mod.rs +++ b/src/daft-plan/src/optimization/mod.rs @@ -1,6 +1,8 @@ mod logical_plan_tracker; mod optimizer; mod rules; +#[cfg(test)] +mod test; pub use optimizer::Optimizer; pub use rules::Transformed; diff --git a/src/daft-plan/src/optimization/optimizer.rs b/src/daft-plan/src/optimization/optimizer.rs index 4229faeb7d..a554632414 100644 --- a/src/daft-plan/src/optimization/optimizer.rs +++ b/src/daft-plan/src/optimization/optimizer.rs @@ -318,7 +318,7 @@ mod tests { use crate::{ logical_ops::{Filter, Project}, optimization::rules::{ApplyOrder, OptimizerRule, Transformed}, - test::dummy_scan_node, + test::{dummy_scan_node, dummy_scan_operator}, LogicalPlan, }; @@ -336,7 +336,7 @@ mod tests { OptimizerConfig::new(5), ); let plan: Arc = - dummy_scan_node(vec![Field::new("a", DataType::Int64)]).build(); + dummy_scan_node(dummy_scan_operator(vec![Field::new("a", DataType::Int64)])).build(); let mut pass_count = 0; let mut did_transform = false; optimizer.optimize(plan.clone(), |new_plan, _, _, transformed, _| { @@ -391,7 +391,7 @@ mod tests { (col("a") + lit(2)).alias("b"), (col("a") + lit(3)).alias("c"), ]; - let plan = dummy_scan_node(vec![Field::new("a", DataType::Int64)]) + let plan = dummy_scan_node(dummy_scan_operator(vec![Field::new("a", DataType::Int64)])) .project(proj_exprs, Default::default())? .build(); let mut pass_count = 0; @@ -426,7 +426,7 @@ mod tests { (col("a") + lit(2)).alias("b"), (col("a") + lit(3)).alias("c"), ]; - let plan = dummy_scan_node(vec![Field::new("a", DataType::Int64)]) + let plan = dummy_scan_node(dummy_scan_operator(vec![Field::new("a", DataType::Int64)])) .project(proj_exprs, Default::default())? .build(); let mut pass_count = 0; @@ -440,51 +440,6 @@ mod tests { Ok(()) } - #[derive(Debug)] - struct RotateProjection { - reverse_first: Mutex, - } - - impl RotateProjection { - pub fn new(reverse_first: bool) -> Self { - Self { - reverse_first: Mutex::new(reverse_first), - } - } - } - - impl OptimizerRule for RotateProjection { - fn apply_order(&self) -> ApplyOrder { - ApplyOrder::TopDown - } - - fn try_optimize( - &self, - plan: Arc, - ) -> DaftResult>> { - let project = match plan.as_ref() { - LogicalPlan::Project(project) => project.clone(), - _ => return Ok(Transformed::No(plan)), - }; - let mut exprs = project.projection.clone(); - let mut reverse = self.reverse_first.lock().unwrap(); - if *reverse { - exprs.reverse(); - *reverse = false; - } else { - exprs.rotate_left(1); - } - Ok(Transformed::Yes( - LogicalPlan::from(Project::try_new( - project.input.clone(), - exprs, - project.resource_request.clone(), - )?) - .into(), - )) - } - } - /// Tests that the optimizer applies multiple rule batches. /// /// This test creates a Filter -> Projection -> Source plan and has 3 rule batches: @@ -517,16 +472,16 @@ mod tests { ], OptimizerConfig::new(20), ); - let fields = vec![Field::new("a", DataType::Int64)]; let proj_exprs = vec![ col("a") + lit(1), (col("a") + lit(2)).alias("b"), (col("a") + lit(3)).alias("c"), ]; let filter_predicate = col("a").lt(&lit(2)); - let plan = dummy_scan_node(fields.clone()) - .project(proj_exprs, Default::default())? - .filter(filter_predicate)? + let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Int64)]); + let plan = dummy_scan_node(scan_op.clone()) + .project(proj_exprs.clone(), Default::default())? + .filter(filter_predicate.clone())? 
.build(); let mut pass_count = 0; let mut did_transform = false; @@ -537,11 +492,24 @@ mod tests { assert!(did_transform); // 3 + 2 + 1 = 6 assert_eq!(pass_count, 6); - let expected = "\ - Filter: [[[col(a) < lit(2)] | lit(false)] | lit(false)] & lit(true)\ - \n Project: col(a) + lit(3) AS c, col(a) + lit(1), col(a) + lit(2) AS b\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, Native storage config = { Use multithreading = true }, Output schema = a#Int64"; - assert_eq!(opt_plan.repr_indent(), expected); + + let mut new_proj_exprs = proj_exprs.clone(); + new_proj_exprs.rotate_left(2); + let new_pred = filter_predicate + .or(&lit(false)) + .or(&lit(false)) + .and(&lit(true)); + let expected = dummy_scan_node(scan_op) + .project(new_proj_exprs, Default::default())? + .filter(new_pred)? + .build(); + assert_eq!( + opt_plan, + expected, + "\n\nOptimized plan not equal to expected.\n\nOptimized:\n{}\n\nExpected:\n{}", + opt_plan.repr_ascii(false), + expected.repr_ascii(false) + ); Ok(()) } @@ -602,4 +570,49 @@ mod tests { )) } } + + #[derive(Debug)] + struct RotateProjection { + reverse_first: Mutex, + } + + impl RotateProjection { + pub fn new(reverse_first: bool) -> Self { + Self { + reverse_first: Mutex::new(reverse_first), + } + } + } + + impl OptimizerRule for RotateProjection { + fn apply_order(&self) -> ApplyOrder { + ApplyOrder::TopDown + } + + fn try_optimize( + &self, + plan: Arc, + ) -> DaftResult>> { + let project = match plan.as_ref() { + LogicalPlan::Project(project) => project.clone(), + _ => return Ok(Transformed::No(plan)), + }; + let mut exprs = project.projection.clone(); + let mut reverse = self.reverse_first.lock().unwrap(); + if *reverse { + exprs.reverse(); + *reverse = false; + } else { + exprs.rotate_left(1); + } + Ok(Transformed::Yes( + LogicalPlan::from(Project::try_new( + project.input.clone(), + exprs, + project.resource_request.clone(), + )?) + .into(), + )) + } + } } diff --git a/src/daft-plan/src/optimization/rules/drop_repartition.rs b/src/daft-plan/src/optimization/rules/drop_repartition.rs index f0449d6a8b..a820e67d0a 100644 --- a/src/daft-plan/src/optimization/rules/drop_repartition.rs +++ b/src/daft-plan/src/optimization/rules/drop_repartition.rs @@ -53,61 +53,54 @@ mod tests { use crate::{ optimization::{ - optimizer::{RuleBatch, RuleExecutionStrategy}, - rules::drop_repartition::DropRepartition, - Optimizer, + rules::drop_repartition::DropRepartition, test::assert_optimized_plan_with_rules_eq, }, - test::dummy_scan_node, + test::{dummy_scan_node, dummy_scan_operator}, LogicalPlan, PartitionSchemeConfig, }; /// Helper that creates an optimizer with the DropRepartition rule registered, optimizes - /// the provided plan with said optimizer, and compares the optimized plan's repr with - /// the provided expected repr. - fn assert_optimized_plan_eq(plan: Arc, expected: &str) -> DaftResult<()> { - let optimizer = Optimizer::with_rule_batches( - vec![RuleBatch::new( - vec![Box::new(DropRepartition::new())], - RuleExecutionStrategy::Once, - )], - Default::default(), - ); - let optimized_plan = optimizer - .optimize_with_rules( - optimizer.rule_batches[0].rules.as_slice(), - plan.clone(), - &optimizer.rule_batches[0].order, - )? - .unwrap() - .clone(); - assert_eq!(optimized_plan.repr_indent(), expected); - - Ok(()) + /// the provided plan with said optimizer, and compares the optimized plan with + /// the provided expected plan. 
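A pattern worth calling out in these test rewrites: instead of comparing `repr_indent()` strings, each test now builds the expected plan with the same dummy scan operator and asserts structural equality via `assert_optimized_plan_with_rules_eq`, so display-format changes no longer break optimizer tests. A toy Python sketch of the same idea, with plans as nested tuples and each rule applied once (the real helper runs the rules through the Optimizer's rule-batch machinery):

    from typing import Callable, List, Tuple

    Plan = Tuple[str, str, tuple]   # (op name, payload, children)
    Rule = Callable[[Plan], Plan]

    def assert_optimized_plan_with_rules_eq(plan: Plan, expected: Plan, rules: List[Rule]) -> None:
        optimized = plan
        for rule in rules:
            optimized = rule(optimized)
        # Structural equality instead of repr-string comparison.
        assert optimized == expected, f"{optimized!r} != {expected!r}"

    def drop_upstream_repartition(plan: Plan) -> Plan:
        # Collapse back-to-back Repartitions, keeping only the downstream one.
        op, payload, children = plan
        if op == "Repartition" and children and children[0][0] == "Repartition":
            return (op, payload, children[0][2])
        return plan

    scan = ("Source", "a#Int64", ())
    plan = ("Repartition", "5", (("Repartition", "10", (scan,)),))
    assert_optimized_plan_with_rules_eq(plan, ("Repartition", "5", (scan,)), [drop_upstream_repartition])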
+ fn assert_optimized_plan_eq( + plan: Arc, + expected: Arc, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq(plan, expected, vec![Box::new(DropRepartition::new())]) } - /// Tests that DropRepartition does drops the upstream Repartition in back-to-back Repartitions if . + /// Tests that DropRepartition does drops the upstream Repartition in back-to-back Repartitions. /// /// Repartition1-Repartition2 -> Repartition1 #[test] fn repartition_dropped_in_back_to_back() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + let num_partitions1 = Some(10); + let num_partitions2 = Some(5); + let partition_by = vec![col("a")]; + let partition_scheme_config = PartitionSchemeConfig::Hash(Default::default()); + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .repartition( - Some(10), - vec![col("a")], - PartitionSchemeConfig::Hash(Default::default()), - )? - .repartition( - Some(5), - vec![col("a")], - PartitionSchemeConfig::Hash(Default::default()), - )? - .build(); - let expected = "\ - Repartition: Scheme = Hash, Number of partitions = 5, Partition by = col(a)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; + ]); + let plan = dummy_scan_node(scan_op.clone()) + .repartition( + num_partitions1, + partition_by.clone(), + partition_scheme_config.clone(), + )? + .repartition( + num_partitions2, + partition_by.clone(), + partition_scheme_config.clone(), + )? + .build(); + let expected = dummy_scan_node(scan_op) + .repartition( + num_partitions2, + partition_by.clone(), + partition_scheme_config.clone(), + )? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-plan/src/optimization/rules/push_down_filter.rs b/src/daft-plan/src/optimization/rules/push_down_filter.rs index c51de47691..130b73a623 100644 --- a/src/daft-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/optimization/rules/push_down_filter.rs @@ -12,11 +12,11 @@ use daft_dsl::{ }, Expr, }; -use daft_scan::{rewrite_predicate_for_partitioning, ScanExternalInfo}; +use daft_scan::rewrite_predicate_for_partitioning; use crate::{ logical_ops::{Concat, Filter, Project, Source}, - source_info::{ExternalInfo, SourceInfo}, + source_info::SourceInfo, LogicalPlan, }; @@ -77,16 +77,10 @@ impl OptimizerRule for PushDownFilter { // Filter pushdown is not supported for in-memory sources. #[cfg(feature = "python")] SourceInfo::InMemoryInfo(_) => return Ok(Transformed::No(plan)), - // Do not pushdown if Source node is already has a limit + // Do not pushdown if Source node already has a limit SourceInfo::ExternalInfo(external_info) if let Some(existing_limit) = - external_info.pushdowns().limit => - { - return Ok(Transformed::No(plan)) - } - // Do not pushdown if we are using python legacy scan info - SourceInfo::ExternalInfo(external_info) - if let ExternalInfo::Legacy(..) = external_info => + external_info.pushdowns.limit => { return Ok(Transformed::No(plan)) } @@ -115,14 +109,10 @@ impl OptimizerRule for PushDownFilter { if has_udf { return Ok(Transformed::No(plan)); } - let new_predicate = external_info.pushdowns().filters.as_ref().map(|f| predicate.and(f)).unwrap_or(predicate.clone()); - let partition_filter = if let ExternalInfo::Scan(ScanExternalInfo {scan_op, ..}) = &external_info { - rewrite_predicate_for_partitioning(new_predicate.clone(), scan_op.0.partitioning_keys())? 
- } else { - None - }; + let new_predicate = external_info.pushdowns.filters.as_ref().map(|f| predicate.and(f)).unwrap_or(predicate.clone()); + let partition_filter = rewrite_predicate_for_partitioning(new_predicate.clone(), external_info.scan_op.0.partitioning_keys())?; let new_pushdowns = - external_info.pushdowns().with_filters(Some(Arc::new(new_predicate))); + external_info.pushdowns.with_filters(Some(Arc::new(new_predicate))); let new_pushdowns = if let Some(pfilter) = partition_filter { new_pushdowns.with_partition_filters(Some(Arc::new(pfilter))) @@ -285,141 +275,151 @@ mod tests { use common_error::DaftResult; use daft_core::{datatypes::Field, DataType}; use daft_dsl::{col, lit}; + use daft_scan::Pushdowns; + use rstest::rstest; use crate::{ - optimization::{ - optimizer::{RuleBatch, RuleExecutionStrategy}, - rules::PushDownFilter, - Optimizer, - }, - test::{dummy_scan_node, dummy_scan_operator_node}, + optimization::{rules::PushDownFilter, test::assert_optimized_plan_with_rules_eq}, + test::{dummy_scan_node, dummy_scan_node_with_pushdowns, dummy_scan_operator}, JoinType, LogicalPlan, PartitionSchemeConfig, }; /// Helper that creates an optimizer with the PushDownFilter rule registered, optimizes - /// the provided plan with said optimizer, and compares the optimized plan's repr with - /// the provided expected repr. - fn assert_optimized_plan_eq(plan: Arc, expected: &str) -> DaftResult<()> { - let optimizer = Optimizer::with_rule_batches( - vec![RuleBatch::new( - vec![Box::new(PushDownFilter::new())], - RuleExecutionStrategy::Once, - )], - Default::default(), - ); - let optimized_plan = optimizer - .optimize_with_rules( - optimizer.rule_batches[0].rules.as_slice(), - plan.clone(), - &optimizer.rule_batches[0].order, - )? - .unwrap() - .clone(); - assert_eq!(optimized_plan.repr_indent(), expected); - - Ok(()) - } - - /// Tests combining of two Filters by merging their predicates. - #[test] - fn filter_combine_with_filter() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ]) - .filter(col("a").lt(&lit(2)))? - .filter(col("b").eq(&lit("foo")))? - .build(); - let expected = "\ - Filter: [col(b) == lit(\"foo\")] & [col(a) < lit(2)]\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; - assert_optimized_plan_eq(plan, expected)?; - Ok(()) + /// the provided plan with said optimizer, and compares the optimized plan with + /// the provided expected plan. + fn assert_optimized_plan_eq( + plan: Arc, + expected: Arc, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq(plan, expected, vec![Box::new(PushDownFilter::new())]) } - /// Tests combining of two Filters into a ScanOperator + /// Tests that we can't pushdown a filter into a ScanOperator that has a limit. #[test] - fn pushdown_filter_into_scan_operator() -> DaftResult<()> { - let plan = dummy_scan_operator_node(vec![ + fn filter_not_pushed_down_into_scan_with_limit() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .filter(col("a").lt(&lit(2)))? - .filter(col("b").eq(&lit("foo")))? 
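The rule logic above now assumes every external source is a scan operator: the incoming predicate is AND-ed with any filters already present in the scan's pushdowns, rewritten for partition pruning where possible, and only then folded back into the source; pushdown is skipped entirely if the source already carries a limit or the predicate contains a UDF-like expression. A toy sketch of just the merge/skip decision, using string predicates and a simplified Pushdowns record (the partition-filter rewrite is omitted):

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass(frozen=True)
    class Pushdowns:
        limit: Optional[int] = None
        filters: Optional[str] = None

    def push_filter_into_scan(pred: str, pushdowns: Pushdowns, pred_has_udf: bool) -> Optional[Pushdowns]:
        """Return updated pushdowns if the scan can absorb the filter, else None."""
        if pushdowns.limit is not None:
            return None   # never push a filter below an existing limit pushdown
        if pred_has_udf:
            return None   # UDF-ish predicates stay as an explicit Filter op
        merged = pred if pushdowns.filters is None else f"[{pred}] & [{pushdowns.filters}]"
        return replace(pushdowns, filters=merged)

    # push_filter_into_scan("b == 'foo'", Pushdowns(filters="a < 2"), False)
    # -> Pushdowns(limit=None, filters="[b == 'foo'] & [a < 2]")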
- .build(); - let expected = "\ - AnonymousScanOperator, File paths = [/foo], Use multithreading = true, File schema = a#Int64, b#Utf8, Partitioning keys = [], Filter pushdown = [col(b) == lit(\"foo\")] & [col(a) < lit(2)], Output schema = a#Int64, b#Utf8"; + ]); + let plan = + dummy_scan_node_with_pushdowns(scan_op, Pushdowns::default().with_limit(Some(1))) + .filter(col("a").lt(&lit(2)))? + .build(); + // Plan should be unchanged after optimization. + let expected = plan.clone(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } - /// Tests that we cant pushdown a filter into a ScanOperator with a limit - #[test] - fn pushdown_filter_into_scan_operator_with_limit() -> DaftResult<()> { - let plan = dummy_scan_operator_node(vec![ + /// Tests combining of two Filters by merging their predicates. + #[rstest] + fn filter_combine_with_filter(#[values(false, true)] push_into_scan: bool) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .limit(1, false)? - .filter(col("a").lt(&lit(2)))? - .build(); - let expected = "\ - Filter: col(a) < lit(2)\ - \n Limit: 1\ - \n AnonymousScanOperator, File paths = [/foo], Use multithreading = true, File schema = a#Int64, b#Utf8, Partitioning keys = [], Output schema = a#Int64, b#Utf8"; + ]); + let scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_scan { None } else { Some(1) }), + ); + let p1 = col("a").lt(&lit(2)); + let p2 = col("b").eq(&lit("foo")); + let plan = scan_plan.filter(p1.clone())?.filter(p2.clone())?.build(); + let merged_filter = p2.and(&p1); + let expected = if push_into_scan { + // Merged filter should be pushed into scan. + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(merged_filter.into())), + ) + .build() + } else { + // Merged filter should not be pushed into scan. + scan_plan.filter(merged_filter)?.build() + }; assert_optimized_plan_eq(plan, expected)?; Ok(()) } - /// Tests that we cant pushdown a filter into a ScanOperator with an udf-ish expression + /// Tests that we can't pushdown a filter into a ScanOperator if it has an udf-ish expression. #[test] - fn pushdown_filter_into_scan_operator_with_udf() -> DaftResult<()> { + fn filter_with_udf_not_pushed_down_into_scan() -> DaftResult<()> { let pred = daft_dsl::functions::uri::download(&col("a"), 1, true, true, None); - let plan = dummy_scan_operator_node(vec![ + let plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .filter(pred.is_null())? .build(); - let expected = "\ - Filter: is_null(download(col(a)))\ - \n AnonymousScanOperator, File paths = [/foo], Use multithreading = true, File schema = a#Int64, b#Utf8, Partitioning keys = [], Output schema = a#Int64, b#Utf8"; + let expected = plan.clone(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that Filter commutes with Projections. - #[test] - fn filter_commutes_with_projection() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_projection( + #[values(false, true)] push_into_scan: bool, + ) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .project(vec![col("a")], Default::default())? - .filter(col("a").lt(&lit(2)))? 
- .build(); - let expected = "\ - Project: col(a)\ - \n Filter: col(a) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; + ]); + let scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_scan { None } else { Some(1) }), + ); + let pred = col("a").lt(&lit(2)); + let proj = vec![col("a")]; + let plan = scan_plan + .project(proj.clone(), Default::default())? + .filter(pred.clone())? + .build(); + let expected_scan_filter = if push_into_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + scan_plan.filter(pred)? + }; + let expected = expected_scan_filter + .project(proj, Default::default())? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that a Filter with multiple columns in its predicate commutes with a Projection on both of those columns. - #[test] - fn filter_commutes_with_projection_multi() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_projection_multi( + #[values(false, true)] push_into_scan: bool, + ) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .project(vec![col("a"), col("b")], Default::default())? - .filter(col("a").lt(&lit(2)).and(&col("b").eq(&lit("foo"))))? - .build(); - let expected = "\ - Project: col(a), col(b)\ - \n Filter: [col(a) < lit(2)] & [col(b) == lit(\"foo\")]\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; + ]); + let scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_scan { None } else { Some(1) }), + ); + let pred = col("a").lt(&lit(2)).and(&col("b").eq(&lit("foo"))); + let proj = vec![col("a"), col("b")]; + let plan = scan_plan + .project(proj.clone(), Default::default())? + .filter(pred.clone())? + .build(); + let expected_scan_filter = if push_into_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + scan_plan.filter(pred)? + }; + let expected = expected_scan_filter + .project(proj, Default::default())? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -427,19 +427,16 @@ mod tests { /// Tests that Filter does not commute with a Projection if the projection expression involves compute. #[test] fn filter_does_not_commute_with_projection_if_compute() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + let plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) // Projection involves compute on filtered column "a". .project(vec![col("a") + lit(1)], Default::default())? .filter(col("a").lt(&lit(2)))? .build(); // Filter should NOT commute with Project, since this would involve redundant computation. 
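The unchanged-plan expectation above captures an invariant worth spelling out: pushing a filter below a projection that computes on the filtered column naively changes the result, and while rewriting the predicate through the projection would preserve it, that duplicates the compute, which is why the rule declines (the ignored deterministic-compute test below sketches the rewrite). A tiny list-comprehension illustration in plain Python, not Daft expressions:

    rows = [0, 1, 2, 3]

    # Filter applied after the computing projection (what the plan specifies):
    after = [a for a in (r + 1 for r in rows) if a < 2]        # [1]

    # Naively pushing the same predicate below the projection changes the answer:
    pushed_naive = [r + 1 for r in rows if r < 2]              # [1, 2]

    # Commuting is only valid if the predicate is rewritten through the
    # projection, i.e. (a + 1) < 2, at the cost of repeating the compute:
    pushed_rewritten = [r + 1 for r in rows if (r + 1) < 2]    # [1]

    assert after == pushed_rewritten and after != pushed_naive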
- let expected = "\ - Filter: col(a) < lit(2)\ - \n Project: col(a) + lit(1)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; + let expected = plan.clone(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -447,166 +444,378 @@ mod tests { /// Tests that Filter commutes with Projection if projection expression involves deterministic compute. // REASON - No expression attribute indicating whether deterministic && (pure || idempotent). #[ignore] - #[test] - fn filter_commutes_with_projection_deterministic_compute() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_projection_deterministic_compute( + #[values(false, true)] push_into_scan: bool, + ) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - // Projection involves compute on filtered column "a". - .project(vec![col("a") + lit(1)], Default::default())? - .filter(col("a").lt(&lit(2)))? - .build(); - let expected = "\ - Project: col(a) + lit(1)\ - \n Filter: [col(a) + lit(1)] < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; + ]); + let scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_scan { None } else { Some(1) }), + ); + let pred = col("a").lt(&lit(2)); + let proj = vec![col("a") + lit(1)]; + let plan = scan_plan + // Projection involves compute on filtered column "a". + .project(proj.clone(), Default::default())? + .filter(pred.clone())? + .build(); + let expected_filter_scan = if push_into_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + scan_plan.filter(pred)? + }; + let expected = expected_filter_scan + .project(proj, Default::default())? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that Filter commutes with Sort. - #[test] - fn filter_commutes_with_sort() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_sort(#[values(false, true)] push_into_scan: bool) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .sort(vec![col("a")], vec![true])? - .filter(col("a").lt(&lit(2)))? - .build(); - let expected = "\ - Sort: Sort by = (col(a), descending)\ - \n Filter: col(a) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; - // TODO(Clark): For tests in which we only care about reordering of operators, maybe switch to a form that leverages the single-node display? - // let expected = format!("{sort}\n {filter}\n {source}"); + ]); + let scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_scan { None } else { Some(1) }), + ); + let pred = col("a").lt(&lit(2)); + let sort_by = vec![col("a")]; + let descending = vec![true]; + let plan = scan_plan + .sort(sort_by.clone(), descending.clone())? + .filter(pred.clone())? + .build(); + let expected_filter_scan = if push_into_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + scan_plan.filter(pred)? 
+ }; + let expected = expected_filter_scan.sort(sort_by, descending)?.build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that Filter commutes with Repartition. - #[test] - fn filter_commutes_with_repartition() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_repartition( + #[values(false, true)] push_into_scan: bool, + ) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .repartition( - Some(1), - vec![col("a")], - PartitionSchemeConfig::Hash(Default::default()), - )? - .filter(col("a").lt(&lit(2)))? - .build(); - let expected = "\ - Repartition: Scheme = Hash, Number of partitions = 1, Partition by = col(a)\ - \n Filter: col(a) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; + ]); + let scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_scan { None } else { Some(1) }), + ); + let pred = col("a").lt(&lit(2)); + let num_partitions = Some(1); + let repartition_by = vec![col("a")]; + let partition_scheme_config = PartitionSchemeConfig::Hash(Default::default()); + let plan = scan_plan + .repartition( + num_partitions, + repartition_by.clone(), + partition_scheme_config.clone(), + )? + .filter(pred.clone())? + .build(); + let expected_filter_scan = if push_into_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + scan_plan.filter(pred)? + }; + let expected = expected_filter_scan + .repartition(num_partitions, repartition_by, partition_scheme_config)? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that Filter commutes with Concat. - #[test] - fn filter_commutes_with_concat() -> DaftResult<()> { - let fields = vec![ + #[rstest] + fn filter_commutes_with_concat( + #[values(false, true)] push_into_left_scan: bool, + #[values(false, true)] push_into_right_scan: bool, + ) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]; - let plan = dummy_scan_node(fields.clone()) - .concat(&dummy_scan_node(fields))? - .filter(col("a").lt(&lit(2)))? + ]); + let left_scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_left_scan { None } else { Some(1) }), + ); + let right_scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), + ); + let pred = col("a").lt(&lit(2)); + let plan = left_scan_plan + .concat(&right_scan_plan)? + .filter(pred.clone())? + .build(); + let expected_left_filter_scan = if push_into_left_scan { + dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_filters(Some(pred.clone().into())), + ) + } else { + left_scan_plan.filter(pred.clone())? + }; + let expected_right_filter_scan = if push_into_right_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + right_scan_plan.filter(pred)? + }; + let expected = expected_left_filter_scan + .concat(&expected_right_filter_scan)? + .build(); + assert_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + /// Tests that Filter commutes with Join. 
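Before the join variant below, the identity these tests rely on is simple: filtering the combined output equals filtering each input first, so the rule clones the predicate into both children of a Concat (and, for a Join, into whichever side or sides contain the referenced columns), pushing it further into each child's scan pushdowns when possible. A small plain-Python illustration of the Concat case:

    left = [1, 5, 2, 7]
    right = [0, 9, 1]

    def pred(x: int) -> bool:
        return x < 2

    # Filtering after the concat...
    filtered_after = [x for x in left + right if pred(x)]
    # ...equals concatenating the independently filtered inputs.
    filtered_before = [x for x in left if pred(x)] + [x for x in right if pred(x)]

    assert filtered_after == filtered_before == [1, 0, 1]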
+ #[rstest] + fn filter_commutes_with_join( + #[values(false, true)] push_into_left_scan: bool, + #[values(false, true)] push_into_right_scan: bool, + ) -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + ]); + let left_scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_left_scan { None } else { Some(1) }), + ); + let right_scan_plan = dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), + ); + let join_on = vec![col("b")]; + let pred = col("a").lt(&lit(2)); + let plan = left_scan_plan + .join( + &right_scan_plan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? + .filter(pred.clone())? + .build(); + let expected_left_filter_scan = if push_into_left_scan { + dummy_scan_node_with_pushdowns( + scan_op.clone(), + Pushdowns::default().with_filters(Some(pred.clone().into())), + ) + } else { + left_scan_plan.filter(pred.clone())? + }; + let expected_right_filter_scan = if push_into_right_scan { + dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + right_scan_plan.filter(pred)? + }; + let expected = expected_left_filter_scan + .join( + &expected_right_filter_scan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? .build(); - let expected = "\ - Concat\ - \n Filter: col(a) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8\ - \n Filter: col(a) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that Filter can be pushed into the left side of a Join. - #[test] - fn filter_commutes_with_join_left_side() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_join_left_side( + #[values(false, true)] push_into_left_scan: bool, + ) -> DaftResult<()> { + let left_scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .join( - &dummy_scan_node(vec![ - Field::new("b", DataType::Utf8), - Field::new("c", DataType::Float64), - ]), - vec![col("b")], - vec![col("b")], - JoinType::Inner, - None, - )? - .filter(col("a").lt(&lit(2)))? 
- .build(); - let expected = "\ - Join: Type = Inner, Strategy = Auto, On = col(b), Output schema = a#Int64, b#Utf8, c#Float64\ - \n Filter: col(a) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8\ - \n Source: Json, File paths = [/foo], File schema = b#Utf8, c#Float64, Native storage config = { Use multithreading = true }, Output schema = b#Utf8, c#Float64"; + ]); + let right_scan_op = dummy_scan_operator(vec![ + Field::new("b", DataType::Utf8), + Field::new("c", DataType::Float64), + ]); + let left_scan_plan = dummy_scan_node_with_pushdowns( + left_scan_op.clone(), + Pushdowns::default().with_limit(if push_into_left_scan { None } else { Some(1) }), + ); + let right_scan_plan = dummy_scan_node(right_scan_op.clone()); + let join_on = vec![col("b")]; + let pred = col("a").lt(&lit(2)); + let plan = left_scan_plan + .join( + &right_scan_plan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? + .filter(pred.clone())? + .build(); + let expected_left_filter_scan = if push_into_left_scan { + dummy_scan_node_with_pushdowns( + left_scan_op.clone(), + Pushdowns::default().with_filters(Some(pred.clone().into())), + ) + } else { + left_scan_plan.filter(pred.clone())? + }; + let expected = expected_left_filter_scan + .join( + &right_scan_plan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Tests that Filter can be pushed into the right side of a Join. - #[test] - fn filter_commutes_with_join_right_side() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + #[rstest] + fn filter_commutes_with_join_right_side( + #[values(false, true)] push_into_right_scan: bool, + ) -> DaftResult<()> { + let left_scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .join( - &dummy_scan_node(vec![ - Field::new("b", DataType::Utf8), - Field::new("c", DataType::Float64), - ]), - vec![col("b")], - vec![col("b")], - JoinType::Inner, - None, - )? - .filter(col("c").lt(&lit(2.0)))? - .build(); - let expected = "\ - Join: Type = Inner, Strategy = Auto, On = col(b), Output schema = a#Int64, b#Utf8, c#Float64\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8\ - \n Filter: col(c) < lit(2.0)\ - \n Source: Json, File paths = [/foo], File schema = b#Utf8, c#Float64, Native storage config = { Use multithreading = true }, Output schema = b#Utf8, c#Float64"; + ]); + let right_scan_op = dummy_scan_operator(vec![ + Field::new("b", DataType::Utf8), + Field::new("c", DataType::Float64), + ]); + let left_scan_plan = dummy_scan_node(left_scan_op.clone()); + let right_scan_plan = dummy_scan_node_with_pushdowns( + right_scan_op.clone(), + Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), + ); + let join_on = vec![col("b")]; + let pred = col("c").lt(&lit(2.0)); + let plan = left_scan_plan + .join( + &right_scan_plan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? + .filter(pred.clone())? + .build(); + let expected_right_filter_scan = if push_into_right_scan { + dummy_scan_node_with_pushdowns( + right_scan_op.clone(), + Pushdowns::default().with_filters(Some(pred.clone().into())), + ) + } else { + right_scan_plan.filter(pred.clone())? 
+ }; + let expected = left_scan_plan + .join( + &expected_right_filter_scan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? + .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } - /// Tests that Filter can be pushed into both sides of a Join. - #[test] - fn filter_commutes_with_join_both_sides() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ - Field::new("a", DataType::Int64), + /// Tests that Filter on join key commutes with Join. + #[rstest] + fn filter_commutes_with_join_on_join_key( + #[values(false, true)] push_into_left_scan: bool, + #[values(false, true)] push_into_right_scan: bool, + ) -> DaftResult<()> { + let left_scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Utf8), Field::new("b", DataType::Int64), Field::new("c", DataType::Float64), - ]) - .join( - &dummy_scan_node(vec![Field::new("b", DataType::Int64)]), - vec![col("b")], - vec![col("b")], - JoinType::Inner, - None, - )? - .filter(col("b").lt(&lit(2)))? - .build(); - let expected = "\ - Join: Type = Inner, Strategy = Auto, On = col(b), Output schema = a#Int64, b#Int64, c#Float64\ - \n Filter: col(b) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, c#Float64, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Int64, c#Float64\ - \n Filter: col(b) < lit(2)\ - \n Source: Json, File paths = [/foo], File schema = b#Int64, Native storage config = { Use multithreading = true }, Output schema = b#Int64"; + ]); + let right_scan_op = dummy_scan_operator(vec![ + Field::new("b", DataType::Int64), + Field::new("d", DataType::Boolean), + ]); + let left_scan_plan = dummy_scan_node_with_pushdowns( + left_scan_op.clone(), + Pushdowns::default().with_limit(if push_into_left_scan { None } else { Some(1) }), + ); + let right_scan_plan = dummy_scan_node_with_pushdowns( + right_scan_op.clone(), + Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), + ); + let join_on = vec![col("b")]; + let pred = col("b").lt(&lit(2)); + let plan = left_scan_plan + .join( + &right_scan_plan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? + .filter(pred.clone())? + .build(); + let expected_left_filter_scan = if push_into_left_scan { + dummy_scan_node_with_pushdowns( + left_scan_op.clone(), + Pushdowns::default().with_filters(Some(pred.clone().into())), + ) + } else { + left_scan_plan.filter(pred.clone())? + }; + let expected_right_filter_scan = if push_into_right_scan { + dummy_scan_node_with_pushdowns( + right_scan_op, + Pushdowns::default().with_filters(Some(pred.into())), + ) + } else { + right_scan_plan.filter(pred)? + }; + let expected = expected_left_filter_scan + .join( + &expected_right_filter_scan, + join_on.clone(), + join_on.clone(), + JoinType::Inner, + None, + )? 
+ .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-plan/src/optimization/rules/push_down_limit.rs b/src/daft-plan/src/optimization/rules/push_down_limit.rs index dab79c3607..dbafeee1d9 100644 --- a/src/daft-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-plan/src/optimization/rules/push_down_limit.rs @@ -1,11 +1,10 @@ use std::sync::Arc; use common_error::DaftResult; -use daft_scan::ScanExternalInfo; use crate::{ logical_ops::{Limit as LogicalLimit, Source}, - source_info::{ExternalInfo, SourceInfo}, + source_info::SourceInfo, LogicalPlan, }; @@ -57,27 +56,21 @@ impl OptimizerRule for PushDownLimit { // Do not pushdown if Source node is already more limited than `limit` SourceInfo::ExternalInfo(external_info) if let Some(existing_limit) = - external_info.pushdowns().limit && existing_limit <= limit => + external_info.pushdowns.limit && existing_limit <= limit => { Ok(Transformed::No(plan)) } // Pushdown limit into the Source node as a "local" limit SourceInfo::ExternalInfo(external_info) => { let new_pushdowns = - external_info.pushdowns().with_limit(Some(limit)); + external_info.pushdowns.with_limit(Some(limit)); let new_external_info = external_info.with_pushdowns(new_pushdowns); let new_source = LogicalPlan::Source(Source::new( source.output_schema.clone(), SourceInfo::ExternalInfo(new_external_info).into(), )) .into(); - let out_plan = - match external_info { - ExternalInfo::Scan(ScanExternalInfo { - scan_op, .. - }) if scan_op.0.can_absorb_limit() => new_source, - _ => plan.with_new_children(&[new_source]).into(), - }; + let out_plan = if external_info.scan_op.0.can_absorb_limit() { new_source } else { plan.with_new_children(&[new_source]).into() }; Ok(Transformed::Yes(out_plan)) } } @@ -120,46 +113,26 @@ mod tests { use daft_core::{datatypes::Field, schema::Schema, DataType}; use daft_dsl::col; use daft_scan::Pushdowns; + use rstest::rstest; use std::sync::Arc; #[cfg(feature = "python")] use pyo3::Python; use crate::{ - optimization::{ - optimizer::{RuleBatch, RuleExecutionStrategy}, - rules::PushDownLimit, - Optimizer, - }, - test::{ - dummy_scan_node, dummy_scan_node_with_pushdowns, - dummy_scan_operator_node_with_pushdowns, - }, + optimization::{rules::PushDownLimit, test::assert_optimized_plan_with_rules_eq}, + test::{dummy_scan_node, dummy_scan_node_with_pushdowns, dummy_scan_operator}, LogicalPlan, PartitionSchemeConfig, }; /// Helper that creates an optimizer with the PushDownLimit rule registered, optimizes - /// the provided plan with said optimizer, and compares the optimized plan's repr with - /// the provided expected repr. - fn assert_optimized_plan_eq(plan: Arc, expected: &str) -> DaftResult<()> { - let optimizer = Optimizer::with_rule_batches( - vec![RuleBatch::new( - vec![Box::new(PushDownLimit::new())], - RuleExecutionStrategy::Once, - )], - Default::default(), - ); - let optimized_plan = optimizer - .optimize_with_rules( - optimizer.rule_batches[0].rules.as_slice(), - plan.clone(), - &optimizer.rule_batches[0].order, - )? - .unwrap() - .clone(); - assert_eq!(optimized_plan.repr_indent(), expected); - - Ok(()) + /// the provided plan with said optimizer, and compares the optimized plan with + /// the provided expected plan. + fn assert_optimized_plan_eq( + plan: Arc, + expected: Arc, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq(plan, expected, vec![Box::new(PushDownLimit::new())]) } /// Tests that Limit pushes into external Source. 
@@ -167,122 +140,101 @@ mod tests { /// Limit-Source -> Source[with_limit] #[test] fn limit_pushes_into_external_source() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + let limit = 5; + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .limit(5, false)? + ]); + let plan = dummy_scan_node(scan_op.clone()) + .limit(limit, false)? + .build(); + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_limit(Some(limit as usize)), + ) + .limit(limit, false)? .build(); - let expected = "\ - Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } - /// Tests that Limit does not push into external Source with existing smaller limit. + /// Tests that Limit does not push into scan with existing smaller limit. /// /// Limit-Source[existing_limit] -> Source[existing_limit] #[test] - fn limit_does_not_push_into_external_source_if_smaller_limit() -> DaftResult<()> { + fn limit_does_not_push_into_scan_if_smaller_limit() -> DaftResult<()> { + let limit = 5; + let existing_limit = 3; + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + ]); let plan = dummy_scan_node_with_pushdowns( - vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ], - Pushdowns::default().with_limit(Some(3)), + scan_op.clone(), + Pushdowns::default().with_limit(Some(existing_limit)), ) - .limit(5, false)? + .limit(limit, false)? .build(); - let expected = "\ - Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 3, Output schema = a#Int64, b#Utf8"; - assert_optimized_plan_eq(plan, expected)?; - Ok(()) - } - - /// Tests that Limit does not push into external Source with existing smaller limit. - /// - /// Limit[x]-Limit[y] -> Limit[min(x,y)] - #[test] - fn limit_folds_with_smaller_limit() -> DaftResult<()> { - let plan = dummy_scan_node_with_pushdowns( - vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ], - Pushdowns::default(), + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_limit(Some(existing_limit)), ) - .limit(5, false)? - .limit(10, false)? + .limit(limit, false)? .build(); - let expected = "\ - Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } - /// Tests that Limit does not push into external Source with existing smaller limit. + /// Tests that Limit does push into scan with existing larger limit. 
/// - /// Limit[x]-Limit[y] -> Limit[min(x,y)] + /// Limit-Source[existing_limit] -> Source[new_limit] #[test] - fn limit_folds_with_large_limit() -> DaftResult<()> { + fn limit_does_push_into_scan_if_larger_limit() -> DaftResult<()> { + let limit = 5; + let existing_limit = 10; + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + ]); let plan = dummy_scan_node_with_pushdowns( - vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ], - Pushdowns::default().with_limit(Some(20)), + scan_op.clone(), + Pushdowns::default().with_limit(Some(existing_limit)), ) - .limit(10, false)? - .limit(5, false)? + .limit(limit, false)? .build(); - let expected = "\ - Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; - assert_optimized_plan_eq(plan, expected)?; - Ok(()) - } - - /// Tests that Limit does push into external Source with existing larger limit. - /// - /// Limit-Source[existing_limit] -> Source[new_limit] - #[test] - fn limit_does_push_into_external_source_if_larger_limit() -> DaftResult<()> { - let plan = dummy_scan_node_with_pushdowns( - vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ], - Pushdowns::default().with_limit(Some(10)), + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_limit(Some(limit as usize)), ) - .limit(5, false)? + .limit(limit, false)? .build(); - let expected = "\ - Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } - /// Tests that Limit does push into external Source with existing larger limit. + /// Tests that multiple adjacent Limits fold into the smallest limit. /// - /// Limit-Source[existing_limit] -> Source[new_limit] - #[test] - fn limit_does_push_into_scan_operator_if_larger_limit() -> DaftResult<()> { - let plan = dummy_scan_operator_node_with_pushdowns( - vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8), - ], - Pushdowns::default().with_limit(Some(10)), + /// Limit[x]-Limit[y] -> Limit[min(x,y)] + #[rstest] + fn limit_folds_with_smaller_limit( + #[values(false, true)] smaller_first: bool, + ) -> DaftResult<()> { + let smaller_limit = 5; + let limit = 10; + let scan_op = dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + ]); + let plan = dummy_scan_node(scan_op.clone()) + .limit(if smaller_first { smaller_limit } else { limit }, false)? + .limit(if smaller_first { limit } else { smaller_limit }, false)? + .build(); + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_limit(Some(smaller_limit as usize)), ) - .limit(5, false)? + .limit(smaller_limit, false)? .build(); - let expected = "\ - Limit: 5\ - \n AnonymousScanOperator, File paths = [/foo], Use multithreading = true, File schema = a#Int64, b#Utf8, Partitioning keys = [], Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -299,10 +251,7 @@ mod tests { LogicalPlanBuilder::in_memory_scan("foo", py_obj, schema, Default::default(), 5)? .limit(5, false)? .build(); - let expected = "\ - Limit: 5\ - \n . 
Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Utf8"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_eq(plan.clone(), plan)?; Ok(()) } @@ -311,21 +260,29 @@ mod tests { /// Limit-Repartition-Source -> Repartition-Source[with_limit] #[test] fn limit_commutes_with_repartition() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + let limit = 5; + let num_partitions = Some(1); + let partition_by = vec![col("a")]; + let partition_scheme_config = PartitionSchemeConfig::Hash(Default::default()); + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .repartition( - Some(1), - vec![col("a")], - PartitionSchemeConfig::Hash(Default::default()), - )? - .limit(5, false)? + ]); + let plan = dummy_scan_node(scan_op.clone()) + .repartition( + num_partitions, + partition_by.clone(), + partition_scheme_config.clone(), + )? + .limit(limit, false)? + .build(); + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_limit(Some(limit as usize)), + ) + .limit(limit, false)? + .repartition(num_partitions, partition_by, partition_scheme_config)? .build(); - let expected = "\ - Repartition: Scheme = Hash, Number of partitions = 1, Partition by = col(a)\ - \n Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -335,17 +292,23 @@ mod tests { /// Limit-Project-Source -> Project-Source[with_limit] #[test] fn limit_commutes_with_projection() -> DaftResult<()> { - let plan = dummy_scan_node(vec![ + let limit = 5; + let proj = vec![col("a")]; + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) - .project(vec![col("a")], Default::default())? - .limit(5, false)? + ]); + let plan = dummy_scan_node(scan_op.clone()) + .project(proj.clone(), Default::default())? + .limit(limit, false)? + .build(); + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_limit(Some(limit as usize)), + ) + .limit(limit, false)? + .project(proj, Default::default())? 
.build(); - let expected = "\ - Project: col(a)\ - \n Limit: 5\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Utf8, Native storage config = { Use multithreading = true }, Limit pushdown = 5, Output schema = a#Int64, b#Utf8"; assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-plan/src/optimization/rules/push_down_projection.rs b/src/daft-plan/src/optimization/rules/push_down_projection.rs index afcce3f2b3..3e6aad50b3 100644 --- a/src/daft-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/optimization/rules/push_down_projection.rs @@ -159,7 +159,7 @@ impl PushDownProjection { let new_source: LogicalPlan = Source::new( schema.into(), Arc::new(SourceInfo::ExternalInfo(external_info.with_pushdowns( - external_info.pushdowns().with_columns(Some(Arc::new( + external_info.pushdowns.with_columns(Some(Arc::new( required_columns.iter().cloned().collect(), ))), ))), @@ -504,39 +504,26 @@ mod tests { use common_error::DaftResult; use daft_core::{datatypes::Field, DataType}; use daft_dsl::{col, lit}; + use daft_scan::Pushdowns; use crate::{ - optimization::{ - optimizer::{RuleBatch, RuleExecutionStrategy}, - rules::PushDownProjection, - Optimizer, - }, - test::dummy_scan_node, + optimization::{rules::PushDownProjection, test::assert_optimized_plan_with_rules_eq}, + test::{dummy_scan_node, dummy_scan_node_with_pushdowns, dummy_scan_operator}, LogicalPlan, }; /// Helper that creates an optimizer with the PushDownProjection rule registered, optimizes - /// the provided plan with said optimizer, and compares the optimized plan's repr with - /// the provided expected repr. - fn assert_optimized_plan_eq(plan: Arc, expected: &str) -> DaftResult<()> { - let optimizer = Optimizer::with_rule_batches( - vec![RuleBatch::new( - vec![Box::new(PushDownProjection::new())], - RuleExecutionStrategy::Once, - )], - Default::default(), - ); - let optimized_plan = optimizer - .optimize_with_rules( - optimizer.rule_batches[0].rules.as_slice(), - plan.clone(), - &optimizer.rule_batches[0].order, - )? - .unwrap() - .clone(); - assert_eq!(optimized_plan.repr_indent(), expected); - - Ok(()) + /// the provided plan with said optimizer, and compares the optimized plan with + /// the provided expected plan. + fn assert_optimized_plan_eq( + plan: Arc, + expected: Arc, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq( + plan, + expected, + vec![Box::new(PushDownProjection::new())], + ) } /// Projection merging: Ensure factored projections do not get merged. @@ -546,12 +533,12 @@ mod tests { let a4 = &a2 + &a2; let a8 = &a4 + &a4; let expressions = vec![a8.alias("x")]; - let unoptimized = dummy_scan_node(vec![Field::new("a", DataType::Int64)]) + let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Int64)]); + let plan = dummy_scan_node(scan_op) .project(expressions, Default::default())? .build(); - let expected = unoptimized.repr_indent(); - assert_optimized_plan_eq(unoptimized, expected.as_str())?; + assert_optimized_plan_eq(plan.clone(), plan)?; Ok(()) } @@ -559,57 +546,65 @@ mod tests { /// in both the parent and the child. #[test] fn test_merge_projections() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) - .project( - vec![col("a") + lit(1), col("b") + lit(2), col("a").alias("c")], - Default::default(), - )? - .project( - vec![col("a") + lit(3), col("b"), col("c") + lit(4)], - Default::default(), - )? 
- .build(); + ]); + let proj1 = vec![col("a") + lit(1), col("b") + lit(2), col("a").alias("c")]; + let proj2 = vec![col("a") + lit(3), col("b"), col("c") + lit(4)]; + let plan = dummy_scan_node(scan_op.clone()) + .project(proj1, Default::default())? + .project(proj2, Default::default())? + .build(); + + let merged_proj = vec![ + col("a") + lit(1) + lit(3), + col("b") + lit(2), + col("a").alias("c") + lit(4), + ]; + let expected = dummy_scan_node(scan_op) + .project(merged_proj, Default::default())? + .build(); - let expected = "\ - Project: [col(a) + lit(1)] + lit(3), col(b) + lit(2), col(a) + lit(4)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Int64"; - assert_optimized_plan_eq(unoptimized, expected)?; + assert_optimized_plan_eq(plan, expected)?; Ok(()) } /// Projection dropping: Test that a no-op projection is dropped. #[test] fn test_drop_projection() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) - .project(vec![col("a"), col("b")], Default::default())? - .build(); + ]); + let plan = dummy_scan_node(scan_op.clone()) + .project(vec![col("a"), col("b")], Default::default())? + .build(); + + let expected = dummy_scan_node(scan_op).build(); - let expected = "\ - Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Int64"; - assert_optimized_plan_eq(unoptimized, expected)?; + assert_optimized_plan_eq(plan, expected)?; Ok(()) } + /// Projection dropping: Test that projections doing reordering are not dropped. #[test] fn test_dont_drop_projection() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) - .project(vec![col("b"), col("a")], Default::default())? - .build(); + ]); + let proj = vec![col("b"), col("a")]; + let plan = dummy_scan_node(scan_op.clone()) + .project(proj.clone(), Default::default())? + .build(); - let expected = "\ - Project: col(b), col(a)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Int64"; - assert_optimized_plan_eq(unoptimized, expected)?; + let expected = dummy_scan_node(scan_op) + .project(proj, Default::default())? + .build(); + + assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -617,17 +612,24 @@ mod tests { /// Projection<-Source #[test] fn test_projection_source() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) - .project(vec![col("b") + lit(3)], Default::default())? + ]); + let proj = vec![col("b") + lit(3)]; + let plan = dummy_scan_node(scan_op.clone()) + .project(proj.clone(), Default::default())? + .build(); + + let proj_pushdown = vec!["b".to_string()]; + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_columns(Some(Arc::new(proj_pushdown))), + ) + .project(proj, Default::default())? 
.build(); - let expected = "\ - Project: col(b) + lit(3)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, Native storage config = { Use multithreading = true }, Projection pushdown = [b], Output schema = b#Int64"; - assert_optimized_plan_eq(unoptimized, expected)?; + assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -635,25 +637,24 @@ mod tests { /// Projection<-Projection column pruning #[test] fn test_projection_projection() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) - .project( - vec![col("b") + lit(3), col("a"), col("a").alias("x")], - Default::default(), - )? - .project( - vec![col("a"), col("b"), col("b").alias("c")], - Default::default(), - )? - .build(); + ]); + let proj1 = vec![col("b") + lit(3), col("a"), col("a").alias("x")]; + let proj2 = vec![col("a"), col("b"), col("b").alias("c")]; + let plan = dummy_scan_node(scan_op.clone()) + .project(proj1, Default::default())? + .project(proj2.clone(), Default::default())? + .build(); - let expected = "\ - Project: col(a), col(b), col(b) AS c\ - \n Project: col(b) + lit(3), col(a)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, Native storage config = { Use multithreading = true }, Output schema = a#Int64, b#Int64"; - assert_optimized_plan_eq(unoptimized, expected)?; + let new_proj1 = vec![col("b") + lit(3), col("a")]; + let expected = dummy_scan_node(scan_op) + .project(new_proj1, Default::default())? + .project(proj2, Default::default())? + .build(); + + assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -661,20 +662,30 @@ mod tests { /// Projection<-Aggregation column pruning #[test] fn test_projection_aggregation() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), Field::new("c", DataType::Int64), - ]) - .aggregate(vec![col("a").mean(), col("b").mean()], vec![col("c")])? - .project(vec![col("a")], Default::default())? + ]); + let agg = vec![col("a").mean(), col("b").mean()]; + let group_by = vec![col("c")]; + let proj = vec![col("a")]; + let plan = dummy_scan_node(scan_op.clone()) + .aggregate(agg, group_by.clone())? + .project(proj.clone(), Default::default())? + .build(); + + let proj_pushdown = vec!["a".to_string(), "c".to_string()]; + let new_agg = vec![col("a").mean()]; + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_columns(Some(Arc::new(proj_pushdown))), + ) + .aggregate(new_agg, group_by)? + .project(proj, Default::default())? .build(); - let expected = "\ - Project: col(a)\ - \n Aggregation: mean(col(a)), Group by = col(c), Output schema = c#Int64, a#Float64\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Int64, c#Int64, Native storage config = { Use multithreading = true }, Projection pushdown = [a, c], Output schema = a#Int64, c#Int64"; - assert_optimized_plan_eq(unoptimized, expected)?; + assert_optimized_plan_eq(plan, expected)?; Ok(()) } @@ -682,20 +693,28 @@ mod tests { /// Projection<-X pushes down the combined required columns #[test] fn test_projection_pushdown() -> DaftResult<()> { - let unoptimized = dummy_scan_node(vec![ + let scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Boolean), Field::new("c", DataType::Int64), - ]) - .filter(col("b"))? - .project(vec![col("a")], Default::default())? 
+ ]); + let pred = col("b"); + let proj = vec![col("a")]; + let plan = dummy_scan_node(scan_op.clone()) + .filter(pred.clone())? + .project(proj.clone(), Default::default())? + .build(); + + let proj_pushdown = vec!["a".to_string(), "b".to_string()]; + let expected = dummy_scan_node_with_pushdowns( + scan_op, + Pushdowns::default().with_columns(Some(Arc::new(proj_pushdown))), + ) + .filter(pred)? + .project(proj, Default::default())? .build(); - let expected = "\ - Project: col(a)\ - \n Filter: col(b)\ - \n Source: Json, File paths = [/foo], File schema = a#Int64, b#Boolean, c#Int64, Native storage config = { Use multithreading = true }, Projection pushdown = [a, b], Output schema = a#Int64, b#Boolean"; - assert_optimized_plan_eq(unoptimized, expected)?; + assert_optimized_plan_eq(plan, expected)?; Ok(()) } diff --git a/src/daft-plan/src/optimization/test/mod.rs b/src/daft-plan/src/optimization/test/mod.rs new file mode 100644 index 0000000000..0e09c84eec --- /dev/null +++ b/src/daft-plan/src/optimization/test/mod.rs @@ -0,0 +1,44 @@ +use std::sync::Arc; + +use common_error::DaftResult; + +use crate::{ + optimization::{ + optimizer::{RuleBatch, RuleExecutionStrategy}, + Optimizer, + }, + LogicalPlan, +}; + +use super::optimizer::OptimizerRuleInBatch; + +/// Helper that creates an optimizer with the provided rules registered, optimizes +/// the provided plan with said optimizer, and compares the optimized plan with +/// the provided expected plan. +pub fn assert_optimized_plan_with_rules_eq( + plan: Arc, + expected: Arc, + rules: Vec>, +) -> DaftResult<()> { + let optimizer = Optimizer::with_rule_batches( + vec![RuleBatch::new(rules, RuleExecutionStrategy::Once)], + Default::default(), + ); + let optimized_plan = optimizer + .optimize_with_rules( + optimizer.rule_batches[0].rules.as_slice(), + plan.clone(), + &optimizer.rule_batches[0].order, + )? 
+ .unwrap() + .clone(); + assert_eq!( + optimized_plan, + expected, + "\n\nOptimized plan not equal to expected.\n\nOptimized:\n{}\n\nExpected:\n{}", + optimized_plan.repr_ascii(false), + expected.repr_ascii(false) + ); + + Ok(()) +} diff --git a/src/daft-plan/src/physical_ops/csv.rs b/src/daft-plan/src/physical_ops/csv.rs index eed7ab7889..ccff3ff7c4 100644 --- a/src/daft-plan/src/physical_ops/csv.rs +++ b/src/daft-plan/src/physical_ops/csv.rs @@ -1,53 +1,8 @@ -use std::sync::Arc; - use daft_core::schema::SchemaRef; -use daft_scan::Pushdowns; -use crate::{ - physical_plan::PhysicalPlanRef, sink_info::OutputFileInfo, source_info::LegacyExternalInfo, - PartitionSpec, -}; +use crate::{physical_plan::PhysicalPlanRef, sink_info::OutputFileInfo}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] -pub struct TabularScanCsv { - pub projection_schema: SchemaRef, - pub external_info: LegacyExternalInfo, - pub partition_spec: Arc, - pub pushdowns: Pushdowns, -} - -impl TabularScanCsv { - pub(crate) fn new( - projection_schema: SchemaRef, - external_info: LegacyExternalInfo, - partition_spec: Arc, - pushdowns: Pushdowns, - ) -> Self { - Self { - projection_schema, - external_info, - partition_spec, - pushdowns, - } - } - - pub fn multiline_display(&self) -> Vec { - let mut res = vec![]; - res.push("TabularScanCsv:".to_string()); - res.push(format!( - "Projection schema = {}", - self.projection_schema.short_string() - )); - res.extend(self.external_info.multiline_display()); - res.push(format!( - "Partition spec = {{ {} }}", - self.partition_spec.multiline_display().join(", ") - )); - res - } -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct TabularWriteCsv { pub schema: SchemaRef, diff --git a/src/daft-plan/src/physical_ops/explode.rs b/src/daft-plan/src/physical_ops/explode.rs index b0e6308ddc..7bb29785d7 100644 --- a/src/daft-plan/src/physical_ops/explode.rs +++ b/src/daft-plan/src/physical_ops/explode.rs @@ -82,18 +82,22 @@ mod tests { use daft_core::{datatypes::Field, DataType}; use daft_dsl::{col, Expr}; - use crate::{planner::plan, test::dummy_scan_node, PartitionSchemeConfig, PartitionSpec}; + use crate::{ + planner::plan, + test::{dummy_scan_node, dummy_scan_operator}, + PartitionSchemeConfig, PartitionSpec, + }; /// do not destroy the partition spec. 
#[test] fn test_partition_spec_preserving() -> DaftResult<()> { let cfg = DaftExecutionConfig::default().into(); - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::List(Box::new(DataType::Int64))), Field::new("c", DataType::Int64), - ]) + ])) .repartition( Some(3), vec![Expr::Column("a".into())], @@ -123,11 +127,11 @@ mod tests { fn test_partition_spec_destroying() -> DaftResult<()> { let cfg = DaftExecutionConfig::default().into(); - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::List(Box::new(DataType::Int64))), Field::new("c", DataType::Int64), - ]) + ])) .repartition( Some(3), vec![Expr::Column("a".into()), Expr::Column("b".into())], diff --git a/src/daft-plan/src/physical_ops/json.rs b/src/daft-plan/src/physical_ops/json.rs index 6d0f51229d..1aea9c4a41 100644 --- a/src/daft-plan/src/physical_ops/json.rs +++ b/src/daft-plan/src/physical_ops/json.rs @@ -1,53 +1,8 @@ -use std::sync::Arc; - use daft_core::schema::SchemaRef; -use daft_scan::Pushdowns; -use crate::{ - physical_plan::PhysicalPlanRef, sink_info::OutputFileInfo, source_info::LegacyExternalInfo, - PartitionSpec, -}; +use crate::{physical_plan::PhysicalPlanRef, sink_info::OutputFileInfo}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct TabularScanJson { - pub projection_schema: SchemaRef, - pub external_info: LegacyExternalInfo, - pub partition_spec: Arc, - pub pushdowns: Pushdowns, -} - -impl TabularScanJson { - pub(crate) fn new( - projection_schema: SchemaRef, - external_info: LegacyExternalInfo, - partition_spec: Arc, - pushdowns: Pushdowns, - ) -> Self { - Self { - projection_schema, - external_info, - partition_spec, - pushdowns, - } - } - - pub fn multiline_display(&self) -> Vec { - let mut res = vec![]; - res.push("TabularScanJson:".to_string()); - res.push(format!( - "Projection schema = {}", - self.projection_schema.short_string() - )); - res.extend(self.external_info.multiline_display()); - res.push(format!( - "Partition spec = {{ {} }}", - self.partition_spec.multiline_display().join(", ") - )); - res - } -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct TabularWriteJson { pub schema: SchemaRef, diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 849a9470cb..346462f796 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -27,7 +27,7 @@ pub use agg::Aggregate; pub use broadcast_join::BroadcastJoin; pub use coalesce::Coalesce; pub use concat::Concat; -pub use csv::{TabularScanCsv, TabularWriteCsv}; +pub use csv::TabularWriteCsv; pub use empty_scan::EmptyScan; pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; @@ -36,10 +36,10 @@ pub use flatten::Flatten; pub use hash_join::HashJoin; #[cfg(feature = "python")] pub use in_memory::InMemoryScan; -pub use json::{TabularScanJson, TabularWriteJson}; +pub use json::TabularWriteJson; pub use limit::Limit; pub use monotonically_increasing_id::MonotonicallyIncreasingId; -pub use parquet::{TabularScanParquet, TabularWriteParquet}; +pub use parquet::TabularWriteParquet; pub use project::Project; pub use reduce::ReduceMerge; pub use sample::Sample; diff --git a/src/daft-plan/src/physical_ops/parquet.rs b/src/daft-plan/src/physical_ops/parquet.rs 
index a93cac5fc2..371babc4ae 100644 --- a/src/daft-plan/src/physical_ops/parquet.rs +++ b/src/daft-plan/src/physical_ops/parquet.rs @@ -1,53 +1,8 @@ -use std::sync::Arc; - use daft_core::schema::SchemaRef; -use daft_scan::Pushdowns; -use crate::{ - physical_plan::PhysicalPlanRef, sink_info::OutputFileInfo, - source_info::LegacyExternalInfo as ExternalSourceInfo, PartitionSpec, -}; +use crate::{physical_plan::PhysicalPlanRef, sink_info::OutputFileInfo}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct TabularScanParquet { - pub projection_schema: SchemaRef, - pub external_info: ExternalSourceInfo, - pub partition_spec: Arc, - pub pushdowns: Pushdowns, -} - -impl TabularScanParquet { - pub(crate) fn new( - projection_schema: SchemaRef, - external_info: ExternalSourceInfo, - partition_spec: Arc, - pushdowns: Pushdowns, - ) -> Self { - Self { - projection_schema, - external_info, - partition_spec, - pushdowns, - } - } - - pub fn multiline_display(&self) -> Vec { - let mut res = vec![]; - res.push("TabularScanParquet:".to_string()); - res.push(format!( - "Projection schema = {}", - self.projection_schema.short_string() - )); - res.extend(self.external_info.multiline_display()); - res.push(format!( - "Partition spec = {{ {} }}", - self.partition_spec.multiline_display().join(", ") - )); - res - } -} - #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct TabularWriteParquet { pub schema: SchemaRef, diff --git a/src/daft-plan/src/physical_ops/project.rs b/src/daft-plan/src/physical_ops/project.rs index 5fba44d476..5a86035fe2 100644 --- a/src/daft-plan/src/physical_ops/project.rs +++ b/src/daft-plan/src/physical_ops/project.rs @@ -235,7 +235,10 @@ mod tests { use rstest::rstest; use crate::{ - partitioning::PartitionSchemeConfig, planner::plan, test::dummy_scan_node, PartitionSpec, + partitioning::PartitionSchemeConfig, + planner::plan, + test::{dummy_scan_node, dummy_scan_operator}, + PartitionSpec, }; /// Test that projections preserving column inputs, even through aliasing, @@ -248,11 +251,11 @@ mod tests { col("b"), col("a").alias("aa"), ]; - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), Field::new("c", DataType::Int64), - ]) + ])) .repartition( Some(3), vec![Expr::Column("a".into()), Expr::Column("b".into())], @@ -292,11 +295,11 @@ mod tests { use crate::partitioning::PartitionSchemeConfig; let cfg = DaftExecutionConfig::default().into(); - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), Field::new("c", DataType::Int64), - ]) + ])) .repartition( Some(3), vec![Expr::Column("a".into()), Expr::Column("b".into())], @@ -324,11 +327,11 @@ mod tests { let cfg = DaftExecutionConfig::default().into(); let expressions = vec![col("a").alias("y"), col("a"), col("a").alias("z"), col("b")]; - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), Field::new("c", DataType::Int64), - ]) + ])) .repartition( Some(3), vec![Expr::Column("a".into()), Expr::Column("b".into())], diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index e306d1e593..02d311171c 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ 
-1,21 +1,12 @@ #[cfg(feature = "python")] use { - crate::{ - partitioning::PyPartitionSpec, - sink_info::OutputFileInfo, - source_info::{FileInfos, InMemoryInfo, LegacyExternalInfo}, - }, + crate::{partitioning::PyPartitionSpec, sink_info::OutputFileInfo, source_info::InMemoryInfo}, common_io_config::IOConfig, daft_core::python::schema::PySchema, daft_core::schema::SchemaRef, daft_dsl::python::PyExpr, daft_dsl::Expr, - daft_scan::{ - file_format::{FileFormat, FileFormatConfig, PyFileFormatConfig}, - python::pylib::PyScanTask, - storage_config::{PyStorageConfig, StorageConfig}, - Pushdowns, - }, + daft_scan::{file_format::FileFormat, python::pylib::PyScanTask}, pyo3::{ pyclass, pymethods, types::PyBytes, PyObject, PyRef, PyRefMut, PyResult, PyTypeInfo, Python, ToPyObject, @@ -41,9 +32,6 @@ pub(crate) type PhysicalPlanRef = Arc; pub enum PhysicalPlan { #[cfg(feature = "python")] InMemoryScan(InMemoryScan), - TabularScanParquet(TabularScanParquet), - TabularScanCsv(TabularScanCsv), - TabularScanJson(TabularScanJson), TabularScan(TabularScan), EmptyScan(EmptyScan), Project(Project), @@ -78,11 +66,6 @@ impl PhysicalPlan { Self::InMemoryScan(InMemoryScan { partition_spec, .. }) => partition_spec.clone(), Self::TabularScan(TabularScan { partition_spec, .. }) => partition_spec.clone(), Self::EmptyScan(EmptyScan { partition_spec, .. }) => partition_spec.clone(), - Self::TabularScanParquet(TabularScanParquet { partition_spec, .. }) => { - partition_spec.clone() - } - Self::TabularScanCsv(TabularScanCsv { partition_spec, .. }) => partition_spec.clone(), - Self::TabularScanJson(TabularScanJson { partition_spec, .. }) => partition_spec.clone(), Self::Project(Project { partition_spec, .. }) => partition_spec.clone(), Self::Filter(Filter { input, .. }) => input.partition_spec(), Self::Limit(Limit { input, .. }) => input.partition_spec(), @@ -279,10 +262,6 @@ impl PhysicalPlan { } // TODO(Clark): Approximate post-aggregation sizes via grouping estimates + aggregation type. Self::Aggregate(_) => None, - // No size approximation support for legacy I/O. - Self::TabularScanParquet(_) | Self::TabularScanCsv(_) | Self::TabularScanJson(_) => { - None - } // Post-write DataFrame will contain paths to files that were written. // TODO(Clark): Estimate output size via root directory and estimates for # of partitions given partitioning column. Self::TabularWriteParquet(_) | Self::TabularWriteCsv(_) | Self::TabularWriteJson(_) => { @@ -295,11 +274,7 @@ impl PhysicalPlan { match self { #[cfg(feature = "python")] Self::InMemoryScan(..) => vec![], - Self::TabularScan(..) - | Self::EmptyScan(..) - | Self::TabularScanParquet(..) - | Self::TabularScanCsv(..) - | Self::TabularScanJson(..) => vec![], + Self::TabularScan(..) | Self::EmptyScan(..) => vec![], Self::Project(Project { input, .. }) => vec![input], Self::Filter(Filter { input, .. }) => vec![input], Self::Limit(Limit { input, .. }) => vec![input], @@ -335,10 +310,7 @@ impl PhysicalPlan { #[cfg(feature = "python")] Self::InMemoryScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::TabularScan(..) - | Self::EmptyScan(..) - | Self::TabularScanParquet(..) - | Self::TabularScanCsv(..) - | Self::TabularScanJson(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), + | Self::EmptyScan(..) 
=> panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::Project(Project { projection, resource_request, partition_spec, .. }) => Self::Project(Project::try_new( input.clone(), projection.clone(), resource_request.clone(), partition_spec.clone(), ).unwrap()), @@ -364,10 +336,7 @@ impl PhysicalPlan { #[cfg(feature = "python")] Self::InMemoryScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::TabularScan(..) - | Self::EmptyScan(..) - | Self::TabularScanParquet(..) - | Self::TabularScanCsv(..) - | Self::TabularScanJson(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), + | Self::EmptyScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::HashJoin(HashJoin { left_on, right_on, join_type, .. }) => Self::HashJoin(HashJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type)), Self::BroadcastJoin(BroadcastJoin { left_on, @@ -390,9 +359,6 @@ impl PhysicalPlan { Self::InMemoryScan(..) => "InMemoryScan", Self::TabularScan(..) => "TabularScan", Self::EmptyScan(..) => "EmptyScan", - Self::TabularScanParquet(..) => "TabularScanParquet", - Self::TabularScanCsv(..) => "TabularScanCsv", - Self::TabularScanJson(..) => "TabularScanJson", Self::Project(..) => "Project", Self::Filter(..) => "Filter", Self::Limit(..) => "Limit", @@ -425,11 +391,6 @@ impl PhysicalPlan { Self::InMemoryScan(in_memory_scan) => in_memory_scan.multiline_display(), Self::TabularScan(tabular_scan) => tabular_scan.multiline_display(), Self::EmptyScan(empty_scan) => empty_scan.multiline_display(), - Self::TabularScanParquet(tabular_scan_parquet) => { - tabular_scan_parquet.multiline_display() - } - Self::TabularScanCsv(tabular_scan_csv) => tabular_scan_csv.multiline_display(), - Self::TabularScanJson(tabular_scan_json) => tabular_scan_json.multiline_display(), Self::Project(project) => project.multiline_display(), Self::Filter(filter) => filter.multiline_display(), Self::Limit(limit) => limit.multiline_display(), @@ -493,12 +454,8 @@ impl PhysicalPlanScheduler { Ok(self.plan.repr_ascii(simple)) } /// Converts the contained physical plan into an iterator of executable partition tasks. - pub fn to_partition_tasks( - &self, - psets: HashMap>, - is_ray_runner: bool, - ) -> PyResult { - Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets, is_ray_runner)) + pub fn to_partition_tasks(&self, psets: HashMap>) -> PyResult { + Python::with_gil(|py| self.plan.to_partition_tasks(py, &psets)) } } @@ -530,46 +487,6 @@ impl PartitionIterator { } } -#[cfg(feature = "python")] -#[allow(clippy::too_many_arguments)] -fn tabular_scan( - py: Python<'_>, - source_schema: &SchemaRef, - projection_schema: &SchemaRef, - file_infos: &Arc, - file_format_config: &Arc, - storage_config: &Arc, - pushdowns: &Pushdowns, - is_ray_runner: bool, -) -> PyResult { - let columns_to_read = if projection_schema.names() != source_schema.names() { - Some( - projection_schema - .fields - .iter() - .map(|(name, _)| name) - .cloned() - .collect::>(), - ) - } else { - None - }; - let py_iter = py - .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? - .getattr(pyo3::intern!(py, "tabular_scan"))? 
- .call1(( - PySchema::from(source_schema.clone()), - columns_to_read, - file_infos.to_table()?, - PyFileFormatConfig::from(file_format_config.clone()), - PyStorageConfig::from(storage_config.clone()), - pushdowns.limit, - is_ray_runner, - ))?; - - Ok(py_iter.into()) -} - #[allow(clippy::too_many_arguments)] #[cfg(feature = "python")] fn tabular_write( @@ -612,7 +529,6 @@ impl PhysicalPlan { &self, py: Python<'_>, psets: &HashMap>, - is_ray_runner: bool, ) -> PyResult { match self { PhysicalPlan::InMemoryScan(InMemoryScan { @@ -655,79 +571,13 @@ impl PhysicalPlan { Ok(py_iter.into()) } - PhysicalPlan::TabularScanParquet(TabularScanParquet { - projection_schema, - external_info: - LegacyExternalInfo { - source_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - .. - }, - .. - }) => tabular_scan( - py, - source_schema, - projection_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - is_ray_runner, - ), - PhysicalPlan::TabularScanCsv(TabularScanCsv { - projection_schema, - external_info: - LegacyExternalInfo { - source_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - .. - }, - .. - }) => tabular_scan( - py, - source_schema, - projection_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - is_ray_runner, - ), - PhysicalPlan::TabularScanJson(TabularScanJson { - projection_schema, - external_info: - LegacyExternalInfo { - source_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - .. - }, - .. - }) => tabular_scan( - py, - source_schema, - projection_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - is_ray_runner, - ), PhysicalPlan::Project(Project { input, projection, resource_request, .. }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let projection_pyexprs: Vec = projection .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -739,7 +589,7 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Filter(Filter { input, predicate }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let expressions_mod = py.import(pyo3::intern!(py, "daft.expressions.expressions"))?; let py_predicate = expressions_mod @@ -769,7 +619,7 @@ impl PhysicalPlan { eager, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_physical_plan = py.import(pyo3::intern!(py, "daft.execution.physical_plan"))?; let global_limit_iter = py_physical_plan @@ -780,7 +630,7 @@ impl PhysicalPlan { PhysicalPlan::Explode(Explode { input, to_explode, .. }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let explode_pyexprs: Vec = to_explode .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -797,7 +647,7 @@ impl PhysicalPlan { with_replacement, seed, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "sample"))? 
@@ -808,7 +658,7 @@ impl PhysicalPlan { input, column_name, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "monotonically_increasing_id"))? @@ -821,7 +671,7 @@ impl PhysicalPlan { descending, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let sort_by_pyexprs: Vec = sort_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -842,7 +692,7 @@ impl PhysicalPlan { input_num_partitions, output_num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "split"))? @@ -850,7 +700,7 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Flatten(Flatten { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "flatten_plan"))? @@ -861,7 +711,7 @@ impl PhysicalPlan { input, num_partitions, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "fanout_random"))? @@ -873,7 +723,7 @@ impl PhysicalPlan { num_partitions, partition_by, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let partition_by_pyexprs: Vec = partition_by .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -888,7 +738,7 @@ impl PhysicalPlan { "FanoutByRange not implemented, since only use case (sorting) doesn't need it yet." ), PhysicalPlan::ReduceMerge(ReduceMerge { input }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? .getattr(pyo3::intern!(py, "reduce_merge"))? @@ -901,7 +751,7 @@ impl PhysicalPlan { input, .. }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let aggs_as_pyexprs: Vec = aggregations .iter() .map(|agg_expr| PyExpr::from(Expr::Agg(agg_expr.clone()))) @@ -921,7 +771,7 @@ impl PhysicalPlan { num_from, num_to, }) => { - let upstream_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_iter = input.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? .getattr(pyo3::intern!(py, "coalesce"))? @@ -929,8 +779,8 @@ impl PhysicalPlan { Ok(py_iter.into()) } PhysicalPlan::Concat(Concat { other, input }) => { - let upstream_input_iter = input.to_partition_tasks(py, psets, is_ray_runner)?; - let upstream_other_iter = other.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_input_iter = input.to_partition_tasks(py, psets)?; + let upstream_other_iter = other.to_partition_tasks(py, psets)?; let py_iter = py .import(pyo3::intern!(py, "daft.execution.physical_plan"))? 
.getattr(pyo3::intern!(py, "concat"))? @@ -945,8 +795,8 @@ impl PhysicalPlan { join_type, .. }) => { - let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; - let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_left_iter = left.to_partition_tasks(py, psets)?; + let upstream_right_iter = right.to_partition_tasks(py, psets)?; let left_on_pyexprs: Vec = left_on .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -977,8 +827,8 @@ impl PhysicalPlan { left_is_larger, needs_presort, }) => { - let left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; - let right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; + let left_iter = left.to_partition_tasks(py, psets)?; + let right_iter = right.to_partition_tasks(py, psets)?; let left_on_pyexprs: Vec = left_on .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -1022,8 +872,8 @@ impl PhysicalPlan { join_type, is_swapped, }) => { - let upstream_left_iter = left.to_partition_tasks(py, psets, is_ray_runner)?; - let upstream_right_iter = right.to_partition_tasks(py, psets, is_ray_runner)?; + let upstream_left_iter = left.to_partition_tasks(py, psets)?; + let upstream_right_iter = right.to_partition_tasks(py, psets)?; let left_on_pyexprs: Vec = left_on .iter() .map(|expr| PyExpr::from(expr.clone())) @@ -1058,7 +908,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets, is_ray_runner)?, + input.to_partition_tasks(py, psets)?, file_format, schema, root_dir, @@ -1079,7 +929,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets, is_ray_runner)?, + input.to_partition_tasks(py, psets)?, file_format, schema, root_dir, @@ -1100,7 +950,7 @@ impl PhysicalPlan { input, }) => tabular_write( py, - input.to_partition_tasks(py, psets, is_ray_runner)?, + input.to_partition_tasks(py, psets)?, file_format, schema, root_dir, diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index e5d999b2a3..6a507a72b7 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -10,7 +10,6 @@ use common_error::DaftResult; use daft_core::count_mode::CountMode; use daft_core::DataType; use daft_dsl::Expr; -use daft_scan::file_format::FileFormatConfig; use daft_scan::ScanExternalInfo; use crate::logical_ops::{ @@ -24,7 +23,7 @@ use crate::logical_plan::LogicalPlan; use crate::partitioning::PartitionSchemeConfig; use crate::physical_plan::PhysicalPlan; use crate::sink_info::{OutputFileInfo, SinkInfo}; -use crate::source_info::{ExternalInfo as ExternalSourceInfo, LegacyExternalInfo, SourceInfo}; +use crate::source_info::SourceInfo; use crate::FileFormat; use crate::{physical_ops::*, JoinStrategy, PartitionSpec}; @@ -34,56 +33,13 @@ use crate::physical_ops::InMemoryScan; /// Translate a logical plan to a physical plan. pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { match logical_plan { - LogicalPlan::Source(Source { - output_schema, - source_info, - }) => match source_info.as_ref() { - SourceInfo::ExternalInfo(ExternalSourceInfo::Legacy( - ext_info @ LegacyExternalInfo { - file_format_config, - file_infos, - pushdowns, - .. 
- }, - )) => { - let partition_spec = Arc::new(PartitionSpec::new( - PartitionSchemeConfig::Unknown(Default::default()), - file_infos.len(), - None, - )); - match file_format_config.as_ref() { - FileFormatConfig::Parquet(_) => { - Ok(PhysicalPlan::TabularScanParquet(TabularScanParquet::new( - output_schema.clone(), - ext_info.clone(), - partition_spec, - pushdowns.clone(), - ))) - } - FileFormatConfig::Csv(_) => { - Ok(PhysicalPlan::TabularScanCsv(TabularScanCsv::new( - output_schema.clone(), - ext_info.clone(), - partition_spec, - pushdowns.clone(), - ))) - } - FileFormatConfig::Json(_) => { - Ok(PhysicalPlan::TabularScanJson(TabularScanJson::new( - output_schema.clone(), - ext_info.clone(), - partition_spec, - pushdowns.clone(), - ))) - } - } - } - SourceInfo::ExternalInfo(ExternalSourceInfo::Scan(ScanExternalInfo { + LogicalPlan::Source(Source { source_info, .. }) => match source_info.as_ref() { + SourceInfo::ExternalInfo(ScanExternalInfo { pushdowns, scan_op, source_schema, .. - })) => { + }) => { let scan_tasks = scan_op.0.to_scan_tasks(pushdowns.clone())?; let scan_tasks = daft_scan::scan_task_iters::split_by_row_groups( @@ -796,7 +752,7 @@ mod tests { use crate::physical_plan::PhysicalPlan; use crate::planner::plan; - use crate::test::dummy_scan_node; + use crate::test::{dummy_scan_node, dummy_scan_operator}; /// Tests that planner drops a simple Repartition (e.g. df.into_partitions()) the child already has the desired number of partitions. /// @@ -805,10 +761,10 @@ mod tests { fn repartition_dropped_redundant_into_partitions() -> DaftResult<()> { let cfg: Arc = DaftExecutionConfig::default().into(); // dummy_scan_node() will create the default PartitionSpec, which only has a single partition. - let builder = dummy_scan_node(vec![ + let builder = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .repartition(Some(10), vec![], Default::default())? .filter(col("a").lt(&lit(2)))?; assert_eq!( @@ -833,10 +789,10 @@ mod tests { fn repartition_dropped_single_partition() -> DaftResult<()> { let cfg: Arc = DaftExecutionConfig::default().into(); // dummy_scan_node() will create the default PartitionSpec, which only has a single partition. - let builder = dummy_scan_node(vec![ + let builder = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]); + ])); assert_eq!( plan(builder.build().as_ref(), cfg.clone())? .partition_spec() @@ -847,7 +803,7 @@ mod tests { .repartition(Some(1), vec![col("a")], Default::default())? .build(); let physical_plan = plan(logical_plan.as_ref(), cfg.clone())?; - assert_matches!(physical_plan, PhysicalPlan::TabularScanJson(_)); + assert_matches!(physical_plan, PhysicalPlan::TabularScan(_)); Ok(()) } @@ -857,10 +813,10 @@ mod tests { #[test] fn repartition_dropped_same_partition_spec() -> DaftResult<()> { let cfg = DaftExecutionConfig::default().into(); - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), - ]) + ])) .repartition(Some(10), vec![col("a")], Default::default())? .filter(col("a").lt(&lit(2)))? .repartition(Some(10), vec![col("a")], Default::default())? 
@@ -877,10 +833,10 @@ mod tests { #[test] fn repartition_dropped_same_partition_spec_agg() -> DaftResult<()> { let cfg = DaftExecutionConfig::default().into(); - let logical_plan = dummy_scan_node(vec![ + let logical_plan = dummy_scan_node(dummy_scan_operator(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), - ]) + ])) .repartition(Some(10), vec![col("a")], Default::default())? .aggregate( vec![Expr::Agg(AggExpr::Sum(col("a").into()))], diff --git a/src/daft-plan/src/source_info/mod.rs b/src/daft-plan/src/source_info/mod.rs index c6053fb8b6..3d2f7acbed 100644 --- a/src/daft-plan/src/source_info/mod.rs +++ b/src/daft-plan/src/source_info/mod.rs @@ -1,11 +1,9 @@ pub mod file_info; use daft_core::schema::SchemaRef; -use daft_scan::storage_config::StorageConfig; use daft_scan::ScanExternalInfo; -use daft_scan::{file_format::FileFormatConfig, Pushdowns}; pub use file_info::{FileInfo, FileInfos}; use serde::{Deserialize, Serialize}; -use std::{hash::Hash, sync::Arc}; +use std::hash::Hash; #[cfg(feature = "python")] use { daft_scan::py_object_serde::{deserialize_py_object, serialize_py_object}, @@ -17,7 +15,7 @@ use { pub enum SourceInfo { #[cfg(feature = "python")] InMemoryInfo(InMemoryInfo), - ExternalInfo(ExternalInfo), + ExternalInfo(ScanExternalInfo), } #[cfg(feature = "python")] @@ -82,88 +80,3 @@ impl Hash for InMemoryInfo { } } } - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum ExternalInfo { - Scan(ScanExternalInfo), - Legacy(LegacyExternalInfo), -} - -impl ExternalInfo { - pub fn pushdowns(&self) -> &Pushdowns { - match self { - Self::Scan(ScanExternalInfo { pushdowns, .. }) - | Self::Legacy(LegacyExternalInfo { pushdowns, .. }) => pushdowns, - } - } - - pub fn with_pushdowns(&self, pushdowns: Pushdowns) -> Self { - match self { - Self::Scan(external_info) => Self::Scan(ScanExternalInfo { - pushdowns, - ..external_info.clone() - }), - Self::Legacy(external_info) => Self::Legacy(LegacyExternalInfo { - pushdowns, - ..external_info.clone() - }), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct LegacyExternalInfo { - pub source_schema: SchemaRef, - pub file_infos: Arc, - pub file_format_config: Arc, - pub storage_config: Arc, - pub pushdowns: Pushdowns, -} - -impl LegacyExternalInfo { - pub fn new( - source_schema: SchemaRef, - file_infos: Arc, - file_format_config: Arc, - storage_config: Arc, - pushdowns: Pushdowns, - ) -> Self { - Self { - source_schema, - file_infos, - file_format_config, - storage_config, - pushdowns, - } - } - - pub fn multiline_display(&self) -> Vec { - let mut res = vec![]; - res.push(format!( - "File paths = [{}]", - self.file_infos.file_paths.join(", ") - )); - res.push(format!( - "File schema = {}", - self.source_schema.short_string() - )); - let file_format = self.file_format_config.multiline_display(); - if !file_format.is_empty() { - res.push(format!( - "{} config= {}", - self.file_format_config.var_name(), - file_format.join(", ") - )); - } - let storage_config = self.storage_config.multiline_display(); - if !storage_config.is_empty() { - res.push(format!( - "{} storage config = {{ {} }}", - self.storage_config.var_name(), - storage_config.join(", ") - )); - } - res.extend(self.pushdowns.multiline_display()); - res - } -} diff --git a/src/daft-plan/src/test/mod.rs b/src/daft-plan/src/test/mod.rs index 37dbff01ee..b3a852a0a9 100644 --- a/src/daft-plan/src/test/mod.rs +++ b/src/daft-plan/src/test/mod.rs @@ -1,51 +1,33 @@ use std::sync::Arc; use daft_core::{datatypes::Field, 
schema::Schema}; -use daft_scan::{file_format::FileFormatConfig, storage_config::StorageConfig, Pushdowns}; +use daft_scan::{ + file_format::FileFormatConfig, storage_config::StorageConfig, AnonymousScanOperator, Pushdowns, + ScanOperator, +}; -use crate::{builder::LogicalPlanBuilder, source_info::FileInfos, NativeStorageConfig}; - -/// Create a dummy scan node containing the provided fields in its schema. -pub fn dummy_scan_node(fields: Vec) -> LogicalPlanBuilder { - dummy_scan_node_with_pushdowns(fields, Default::default()) -} +use crate::{builder::LogicalPlanBuilder, NativeStorageConfig}; /// Create a dummy scan node containing the provided fields in its schema and the provided limit. -pub fn dummy_scan_node_with_pushdowns( - fields: Vec, - pushdowns: Pushdowns, -) -> LogicalPlanBuilder { +pub fn dummy_scan_operator(fields: Vec) -> Arc { let schema = Arc::new(Schema::new(fields).unwrap()); - - LogicalPlanBuilder::table_scan_with_pushdowns( - FileInfos::new_internal(vec!["/foo".to_string()], vec![None], vec![None]), + Arc::new(AnonymousScanOperator::new( + vec!["/foo".to_string()], schema, FileFormatConfig::Json(Default::default()).into(), StorageConfig::Native(NativeStorageConfig::new_internal(true, None).into()).into(), - pushdowns, - ) - .unwrap() + )) } -pub fn dummy_scan_operator_node(fields: Vec) -> LogicalPlanBuilder { - dummy_scan_operator_node_with_pushdowns(fields, Default::default()) +/// Create a dummy scan node containing the provided fields in its schema. +pub fn dummy_scan_node(scan_op: Arc) -> LogicalPlanBuilder { + dummy_scan_node_with_pushdowns(scan_op, Default::default()) } /// Create a dummy scan node containing the provided fields in its schema and the provided limit. -pub fn dummy_scan_operator_node_with_pushdowns( - fields: Vec, +pub fn dummy_scan_node_with_pushdowns( + scan_op: Arc, pushdowns: Pushdowns, ) -> LogicalPlanBuilder { - let schema = Arc::new(Schema::new(fields).unwrap()); - let anon = daft_scan::AnonymousScanOperator::new( - vec!["/foo".to_string()], - schema, - FileFormatConfig::Json(Default::default()).into(), - StorageConfig::Native(NativeStorageConfig::new_internal(true, None).into()).into(), - ); - LogicalPlanBuilder::table_scan_with_scan_operator( - daft_scan::ScanOperatorRef(Arc::new(anon)), - Some(pushdowns), - ) - .unwrap() + LogicalPlanBuilder::table_scan(daft_scan::ScanOperatorRef(scan_op), Some(pushdowns)).unwrap() } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 2ce839eb33..26e9773df8 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -13,7 +13,7 @@ use crate::{ storage_config::StorageConfig, DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef, }; -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug)] pub struct GlobScanOperator { glob_paths: Vec, file_format_config: Arc, diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index fab6b271f7..449a5d9acd 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -610,6 +610,15 @@ impl ScanExternalInfo { pushdowns, } } + + pub fn with_pushdowns(&self, pushdowns: Pushdowns) -> Self { + Self { + scan_op: self.scan_op.clone(), + source_schema: self.source_schema.clone(), + partitioning_keys: self.partitioning_keys.clone(), + pushdowns, + } + } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 7d92478ecd..52ab451824 100644 --- a/tests/dataframe/test_creation.py +++ 
b/tests/dataframe/test_creation.py @@ -424,11 +424,7 @@ def test_create_dataframe_csv_generate_headers(valid_data: list[dict[str, float] writer.writerows([[item[col] for col in header] for item in valid_data]) f.flush() - cnames = ( - [f"column_{i}" for i in range(1, 6)] - if use_native_downloader or os.environ.get("DAFT_MICROPARTITIONS", "1") == "1" - else [f"f{i}" for i in range(5)] - ) + cnames = [f"column_{i}" for i in range(1, 6)] df = daft.read_csv(fname, has_headers=False, use_native_downloader=use_native_downloader) assert df.column_names == cnames @@ -524,15 +520,6 @@ def test_create_dataframe_csv_specify_schema_no_headers( "column_5": DataType.string(), } - if use_native_downloader == False and os.environ.get("DAFT_MICROPARTITIONS") == "0": - schema_hints_for_csv_without_headers = { - "f0": DataType.float64(), - "f1": DataType.float64(), - "f2": DataType.float64(), - "f3": DataType.float64(), - "f4": DataType.string(), - } - df = daft.read_csv( fname, delimiter="\t", diff --git a/tests/integration/io/parquet/test_reads_s3_minio.py b/tests/integration/io/parquet/test_reads_s3_minio.py index 44ae5e3633..f640513f65 100644 --- a/tests/integration/io/parquet/test_reads_s3_minio.py +++ b/tests/integration/io/parquet/test_reads_s3_minio.py @@ -1,7 +1,5 @@ from __future__ import annotations -import os - import pyarrow as pa import pytest from pyarrow import parquet as pq @@ -36,10 +34,5 @@ def test_minio_parquet_read_no_files(minio_io_config): with minio_create_bucket(minio_io_config, bucket_name=bucket_name) as fs: fs.touch("s3://data-engineering-prod/foo/file.txt") - msg = ( - "Glob path had no matches:" - if os.getenv("DAFT_MICROPARTITIONS", "1") == "1" - else "No files found at s3://data-engineering-prod/foo/\\*\\*.parquet" - ) - with pytest.raises(FileNotFoundError, match=msg): + with pytest.raises(FileNotFoundError, match="Glob path had no matches:"): daft.read_parquet("s3://data-engineering-prod/foo/**.parquet", io_config=minio_io_config) diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py index 75c426874b..e69d1a1105 100644 --- a/tests/io/test_merge_scan_tasks.py +++ b/tests/io/test_merge_scan_tasks.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import os import pytest @@ -33,7 +32,6 @@ def csv_files(tmpdir): return tmpdir -@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions") def test_merge_scan_task_exceed_max(csv_files): with override_merge_scan_tasks_configs(0, 0): df = daft.read_csv(str(csv_files)) @@ -42,7 +40,6 @@ def test_merge_scan_task_exceed_max(csv_files): ), "Should have 3 partitions since all merges are more than the maximum (>0 bytes)" -@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions") def test_merge_scan_task_below_max(csv_files): with override_merge_scan_tasks_configs(21, 22): df = daft.read_csv(str(csv_files)) @@ -51,7 +48,6 @@ def test_merge_scan_task_below_max(csv_files): ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>22 bytes)" -@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions") def test_merge_scan_task_above_min(csv_files): with override_merge_scan_tasks_configs(19, 40): df = daft.read_csv(str(csv_files)) @@ -60,7 +56,6 @@ def test_merge_scan_task_above_min(csv_files): ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>19 bytes)" 
-@pytest.mark.skipif(os.getenv("DAFT_MICROPARTITIONS", "1") == "0", reason="Test can only run on micropartitions") def test_merge_scan_task_below_min(csv_files): with override_merge_scan_tasks_configs(35, 40): df = daft.read_csv(str(csv_files)) From e8697b2229d617a62eb5862e66fb9a1b4ff3333f Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 26 Feb 2024 14:25:35 -0800 Subject: [PATCH 09/11] [DOCS] Fix notebooks by falling back on null for URL downloads (#1951) Co-authored-by: Jay Chia --- tutorials/text_to_image/using_cloud_with_ray.ipynb | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tutorials/text_to_image/using_cloud_with_ray.ipynb b/tutorials/text_to_image/using_cloud_with_ray.ipynb index 364384f92d..d0e44d039c 100644 --- a/tutorials/text_to_image/using_cloud_with_ray.ipynb +++ b/tutorials/text_to_image/using_cloud_with_ray.ipynb @@ -57,7 +57,9 @@ "USE_RAY = False if CI else True\n", "NUM_ROWS_LIMIT = 16 if CI else 160\n", "IO_CONFIG = daft.io.IOConfig(s3=daft.io.S3Config(anonymous=True, region_name=\"us-west-2\")) # Use anonymous-mode for accessing AWS S3\n", - "PARQUET_URL = \"s3://daft-public-data/tutorials/laion-parquet/train-00000-of-00001-6f24a7497df494ae.parquet\"" + "PARQUET_URL = \"s3://daft-public-data/tutorials/laion-parquet/train-00000-of-00001-6f24a7497df494ae.parquet\"\n", + "\n", + "daft.set_planning_config(default_io_config=IO_CONFIG)" ] }, { @@ -145,8 +147,7 @@ }, "outputs": [], "source": [ - "io_config = daft.io.IOConfig(s3=daft.io.S3Config(anonymous=True)) # Use anonymous-mode for accessing S3\n", - "images_df = parquet_df.with_column(\"images\", col(\"URL\").url.download(io_config=io_config))\n", + "images_df = parquet_df.with_column(\"images\", col(\"URL\").url.download(on_error=\"null\"))\n", "images_df.collect()" ] }, @@ -209,7 +210,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.18" }, "vscode": { "interpreter": { From 1a94752c6743bdba5d21f515f2b68e1025eba8a7 Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Mon, 26 Feb 2024 15:17:27 -0800 Subject: [PATCH 10/11] [PERF] Spread scan tasks over Ray cluster. (#1950) This PR forces a `SPREAD` scheduling strategy for scan tasks when using the Ray runner. This should result in better load balancing of read tasks across the Ray cluster, yielding: - better utilization of the aggregate network bandwidth of the cluster, - better memory stability due to a more even post-read object distribution, - better performance of downstream parallel compute operations due to a more even distribution of data over the compute bandwidth of the cluster. 
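For readers unfamiliar with Ray's placement options: `SPREAD` tells the scheduler to distribute tasks across nodes instead of packing them onto a few. A minimal standalone sketch of the primitive this PR leans on (illustrative only, not Daft internals; `read_partition` and the paths are made up):

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def read_partition(path: str) -> str:
    # Stand-in for a scan task that would read one file/partition.
    return path

# With the default strategy Ray tends to pack tasks onto few nodes;
# SPREAD distributes them across the cluster, which is what this PR
# turns on for scan tasks.
refs = [
    read_partition.options(scheduling_strategy="SPREAD").remote(p)
    for p in ["s3://bucket/a.parquet", "s3://bucket/b.parquet"]
]
print(ray.get(refs))
```

In the runner this amounts to injecting `scheduling_strategy: "SPREAD"` into the `ray_options` passed to `.options(**ray_options)` whenever the task's first instruction is a `ScanWithTask` (see the `ray_runner.py` hunk below).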
Closes #1940 --- daft/execution/execution_step.py | 43 +++++++++++++++++++- daft/execution/rust_physical_plan_shim.py | 49 ++--------------------- daft/runners/ray_runner.py | 13 +++++- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 5bd5e483df..a70ab0a1da 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -11,7 +11,7 @@ else: from typing import Protocol -from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest +from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask from daft.expressions import Expression, ExpressionsProjection, col from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema @@ -290,6 +290,47 @@ def num_outputs(self) -> int: return 1 +@dataclass(frozen=True) +class ScanWithTask(SingleOutputInstruction): + scan_task: ScanTask + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + return self._scan(inputs) + + def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + assert len(inputs) == 0 + table = MicroPartition._from_scan_task(self.scan_task) + return [table] + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + assert len(input_metadatas) == 0 + + return [ + PartialPartitionMetadata( + num_rows=self.scan_task.num_rows(), + size_bytes=self.scan_task.size_bytes(), + ) + ] + + +@dataclass(frozen=True) +class EmptyScan(SingleOutputInstruction): + schema: Schema + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + return [MicroPartition.empty(self.schema)] + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + assert len(input_metadatas) == 0 + + return [ + PartialPartitionMetadata( + num_rows=0, + size_bytes=0, + ) + ] + + @dataclass(frozen=True) class WriteFile(SingleOutputInstruction): file_format: FileFormat diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 9b3c94b9f8..a0b1b8de8f 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -1,7 +1,5 @@ from __future__ import annotations -from dataclasses import dataclass - from daft.daft import ( FileFormat, IOConfig, @@ -15,7 +13,7 @@ from daft.expressions import Expression, ExpressionsProjection from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema -from daft.runners.partitioning import PartialPartitionMetadata, PartitionT +from daft.runners.partitioning import PartitionT from daft.table import MicroPartition @@ -30,7 +28,7 @@ def scan_with_tasks( # We can instead right-size and bundle the ScanTask into single-instruction bulk reads. for scan_task in scan_tasks: scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction( - instruction=ScanWithTask(scan_task), + instruction=execution_step.ScanWithTask(scan_task), # Set the filesize as the memory request. # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.) 
resource_request=ResourceRequest(memory_bytes=scan_task.size_bytes()), @@ -43,53 +41,12 @@ def empty_scan( ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: """yield a plan to create an empty Partition""" scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction( - instruction=EmptyScan(schema=schema), + instruction=execution_step.EmptyScan(schema=schema), resource_request=ResourceRequest(memory_bytes=0), ) yield scan_step -@dataclass(frozen=True) -class ScanWithTask(execution_step.SingleOutputInstruction): - scan_task: ScanTask - - def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - return self._scan(inputs) - - def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - assert len(inputs) == 0 - table = MicroPartition._from_scan_task(self.scan_task) - return [table] - - def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: - assert len(input_metadatas) == 0 - - return [ - PartialPartitionMetadata( - num_rows=self.scan_task.num_rows(), - size_bytes=self.scan_task.size_bytes(), - ) - ] - - -@dataclass(frozen=True) -class EmptyScan(execution_step.SingleOutputInstruction): - schema: Schema - - def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: - return [MicroPartition.empty(self.schema)] - - def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: - assert len(input_metadatas) == 0 - - return [ - PartialPartitionMetadata( - num_rows=0, - size_bytes=0, - ) - ] - - def project( input: physical_plan.InProgressPhysicalPlan[PartitionT], projection: list[PyExpr], resource_request: ResourceRequest ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 33cd2cf568..5b0f6c8185 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -40,6 +40,7 @@ MultiOutputPartitionTask, PartitionTask, ReduceInstruction, + ScanWithTask, SingleOutputPartitionTask, ) from daft.filesystem import glob_path_with_stats @@ -644,14 +645,22 @@ def _build_partitions( ray_options = {**ray_options, **_get_ray_task_options(task.resource_request)} if isinstance(task.instructions[0], ReduceInstruction): - build_remote = reduce_and_fanout if isinstance(task.instructions[-1], FanoutInstruction) else reduce_pipeline + build_remote = ( + reduce_and_fanout + if task.instructions and isinstance(task.instructions[-1], FanoutInstruction) + else reduce_pipeline + ) build_remote = build_remote.options(**ray_options) [metadatas_ref, *partitions] = build_remote.remote(daft_execution_config_objref, task.instructions, task.inputs) else: build_remote = ( - fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline + fanout_pipeline + if task.instructions and isinstance(task.instructions[-1], FanoutInstruction) + else single_partition_pipeline ) + if task.instructions and isinstance(task.instructions[0], ScanWithTask): + ray_options["scheduling_strategy"] = "SPREAD" build_remote = build_remote.options(**ray_options) [metadatas_ref, *partitions] = build_remote.remote( daft_execution_config_objref, task.instructions, *task.inputs From cc7b9576f1ce6b94c7c78103340ea0dc0088f8ca Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Mon, 26 Feb 2024 17:09:06 -0800 Subject: [PATCH 11/11] [PERF] scan task in memory estimate (#1901) * Closes: https://github.com/Eventual-Inc/Daft/issues/1898 1. 
When column stats are provided, use only the columns in the materialized schema to estimate in memory size * when column stats are missing, fall back on schema estimate for that field 2. When num_rows is provided, use the materialized schema to estimate in memory size 3. When neither are provided, estimate the in memory size using an inflation factor (same as our writes) and approximate the number of rows. Then use the materialized schema to estimate in memory size 4. thread through the new in memory estimator to the ScanWithTask physical op --- Cargo.lock | 1 + daft/daft.pyi | 5 ++ src/daft-core/src/datatypes/dtype.rs | 37 ++++++++++++++ src/daft-core/src/schema.rs | 7 +++ src/daft-micropartition/src/micropartition.rs | 8 ++-- src/daft-scan/Cargo.toml | 3 +- src/daft-scan/src/lib.rs | 48 +++++++++++++++++-- src/daft-scan/src/python.rs | 13 ++++- src/daft-stats/src/column_stats/mod.rs | 12 +++-- src/daft-stats/src/lib.rs | 1 + src/daft-stats/src/table_stats.rs | 30 ++++++++++-- 11 files changed, 145 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6521e7095e..7f670752de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1404,6 +1404,7 @@ name = "daft-scan" version = "0.2.0-dev0" dependencies = [ "bincode", + "common-daft-config", "common-error", "common-io-config", "daft-core", diff --git a/daft/daft.pyi b/daft/daft.pyi index 8f9ac04c74..63c32914a2 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -593,6 +593,11 @@ class ScanTask: Get number of bytes that will be scanned by this ScanTask. """ ... + def estimate_in_memory_size_bytes(self, cfg: PyDaftExecutionConfig) -> int: + """ + Estimate the In Memory Size of this ScanTask. + """ + ... @staticmethod def catalog_scan_task( file: str, diff --git a/src/daft-core/src/datatypes/dtype.rs b/src/daft-core/src/datatypes/dtype.rs index 65dbfe2b50..9e9815e8cb 100644 --- a/src/daft-core/src/datatypes/dtype.rs +++ b/src/daft-core/src/datatypes/dtype.rs @@ -334,6 +334,43 @@ impl DataType { } } + pub fn estimate_size_bytes(&self) -> Option { + const VARIABLE_TYPE_SIZE: f64 = 20.; + const DEFAULT_LIST_LEN: f64 = 4.; + + let elem_size = match self.to_physical() { + DataType::Null => Some(0.), + DataType::Boolean => Some(0.125), + DataType::Int8 => Some(1.), + DataType::Int16 => Some(2.), + DataType::Int32 => Some(4.), + DataType::Int64 => Some(8.), + DataType::Int128 => Some(16.), + DataType::UInt8 => Some(1.), + DataType::UInt16 => Some(2.), + DataType::UInt32 => Some(4.), + DataType::UInt64 => Some(8.), + DataType::Float32 => Some(4.), + DataType::Float64 => Some(8.), + DataType::Utf8 => Some(VARIABLE_TYPE_SIZE), + DataType::Binary => Some(VARIABLE_TYPE_SIZE), + DataType::FixedSizeList(dtype, len) => { + dtype.estimate_size_bytes().map(|b| b * (len as f64)) + } + DataType::List(dtype) => dtype.estimate_size_bytes().map(|b| b * DEFAULT_LIST_LEN), + DataType::Struct(fields) => Some( + fields + .iter() + .map(|f| f.dtype.estimate_size_bytes().unwrap_or(0f64)) + .sum(), + ), + DataType::Extension(_, dtype, _) => dtype.estimate_size_bytes(), + _ => None, + }; + // add bitmap + elem_size.map(|e| e + 0.125) + } + #[inline] pub fn is_logical(&self) -> bool { matches!( diff --git a/src/daft-core/src/schema.rs b/src/daft-core/src/schema.rs index 04248d9774..e99c85e9b4 100644 --- a/src/daft-core/src/schema.rs +++ b/src/daft-core/src/schema.rs @@ -161,6 +161,13 @@ impl Schema { ); format!("{}\n", table) } + + pub fn estimate_row_size_bytes(&self) -> f64 { + self.fields + .values() + .map(|f| f.dtype.estimate_size_bytes().unwrap_or(0.)) + .sum() 
+ } } impl Eq for Schema {} diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 8fceda4810..372779ae54 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -453,11 +453,9 @@ impl MicroPartition { .iter() .sum(); Some(total_size) - } else if let Some(stats) = &self.statistics { - let row_size = stats.estimate_row_size()?; - Some(row_size * self.len()) - } else if let TableState::Unloaded(scan_task) = guard.deref() && let Some(size_bytes_on_disk) = scan_task.size_bytes_on_disk { - Some(size_bytes_on_disk as usize) + } else if let TableState::Unloaded(scan_task) = guard.deref() { + // TODO: pass in the execution config once we have it available + scan_task.estimate_in_memory_size_bytes(None) } else { // If the table is not loaded, we don't have stats, and we don't have the file size in bytes, return None. // TODO(Clark): Should we pull in the table or trigger a file metadata fetch instead of returning None here? diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml index d7b950f235..17dd9a59a2 100644 --- a/src/daft-scan/Cargo.toml +++ b/src/daft-scan/Cargo.toml @@ -1,5 +1,6 @@ [dependencies] bincode = {workspace = true} +common-daft-config = {path = "../common/daft-config", default-features = false} common-error = {path = "../common/error", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} daft-core = {path = "../daft-core", default-features = false} @@ -21,7 +22,7 @@ tokio = {workspace = true} [features] default = ["python"] -python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-stats/python", "common-io-config/python"] +python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-stats/python", "common-io-config/python", "common-daft-config/python"] [package] edition = {workspace = true} diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 449a5d9acd..7f669ae4cf 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -1,6 +1,7 @@ #![feature(if_let_guard)] #![feature(let_chains)] use std::{ + borrow::Cow, fmt::{Debug, Display}, hash::{Hash, Hasher}, sync::Arc, @@ -21,6 +22,7 @@ mod anonymous; pub use anonymous::AnonymousScanOperator; pub mod file_format; mod glob; +use common_daft_config::DaftExecutionConfig; #[cfg(feature = "python")] pub mod py_object_serde; pub mod scan_task_iters; @@ -385,15 +387,55 @@ impl ScanTask { } pub fn size_bytes(&self) -> Option { + self.size_bytes_on_disk.map(|s| s as usize) + } + + pub fn estimate_in_memory_size_bytes( + &self, + config: Option<&DaftExecutionConfig>, + ) -> Option { + let mat_schema = self.materialized_schema(); self.statistics .as_ref() .and_then(|s| { // Derive in-memory size estimate from table stats. - self.num_rows() - .and_then(|num_rows| Some(num_rows * s.estimate_row_size().ok()?)) + self.num_rows().and_then(|num_rows| { + let row_size = s.estimate_row_size(Some(mat_schema.as_ref())).ok()?; + let estimate = (num_rows as f64) * row_size; + Some(estimate as usize) + }) + }) + .or_else(|| { + // if num rows is provided, use that to estimate row size bytes + self.num_rows().map(|num_rows| { + let row_size = mat_schema.estimate_row_size_bytes(); + let estimate = (num_rows as f64) * row_size; + estimate as usize + }) }) // Fall back on on-disk size. 
- .or_else(|| self.size_bytes_on_disk.map(|s| s as usize)) + .or_else(|| { + self.size_bytes_on_disk.map(|file_size| { + // use inflation factor from config + let config = config + .map_or_else(|| Cow::Owned(DaftExecutionConfig::default()), Cow::Borrowed); + let inflation_factor = match self.file_format_config.as_ref() { + FileFormatConfig::Parquet(_) => config.parquet_inflation_factor, + FileFormatConfig::Csv(_) | FileFormatConfig::Json(_) => { + config.csv_inflation_factor + } + }; + + // estimate number of rows from read schema + let in_mem_size: f64 = (file_size as f64) * inflation_factor; + let read_row_size = self.schema.estimate_row_size_bytes(); + let approx_rows = in_mem_size / read_row_size; + + // estimate in memory size using mat schema estimate + let proj_schema_size = mat_schema.estimate_row_size_bytes(); + (approx_rows * proj_schema_size) as usize + }) + }) } pub fn partition_spec(&self) -> Option<&PartitionSpec> { diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 93cdcbc374..882c9af304 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -35,7 +35,7 @@ pub mod pylib { use crate::file_format::PyFileFormatConfig; use crate::glob::GlobScanOperator; use crate::storage_config::PyStorageConfig; - + use common_daft_config::PyDaftExecutionConfig; #[pyclass(module = "daft.daft", frozen)] #[derive(Debug, Clone)] pub struct ScanOperatorHandle { @@ -245,6 +245,17 @@ pub mod pylib { pub fn size_bytes(&self) -> PyResult> { Ok(self.0.size_bytes().map(i64::try_from).transpose()?) } + + pub fn estimate_in_memory_size_bytes( + &self, + cfg: PyDaftExecutionConfig, + ) -> PyResult> { + Ok(self + .0 + .estimate_in_memory_size_bytes(Some(cfg.config.as_ref())) + .map(i64::try_from) + .transpose()?) + } } #[pymethods] diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index 4b1cd02249..ed2d2dfa72 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -118,12 +118,14 @@ impl ColumnRangeStatistics { } } - pub(crate) fn element_size(&self) -> crate::Result { + pub(crate) fn element_size(&self) -> crate::Result> { match self { - Self::Missing => Ok(0), - Self::Loaded(l, u) => Ok((l.size_bytes().context(DaftCoreComputeSnafu)? - + u.size_bytes().context(DaftCoreComputeSnafu)?) - / 2), + Self::Missing => Ok(None), + Self::Loaded(l, u) => Ok(Some( + ((l.size_bytes().context(DaftCoreComputeSnafu)? + + u.size_bytes().context(DaftCoreComputeSnafu)?) 
as f64) + / 2., + )), } } diff --git a/src/daft-stats/src/lib.rs b/src/daft-stats/src/lib.rs index b55fde3922..3bf362782f 100644 --- a/src/daft-stats/src/lib.rs +++ b/src/daft-stats/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(let_chains)] use common_error::DaftError; use snafu::Snafu; diff --git a/src/daft-stats/src/table_stats.rs b/src/daft-stats/src/table_stats.rs index 973cd8928a..baec34df6e 100644 --- a/src/daft-stats/src/table_stats.rs +++ b/src/daft-stats/src/table_stats.rs @@ -74,12 +74,32 @@ impl TableStatistics { }) } - pub fn estimate_row_size(&self) -> super::Result { - let mut sum_so_far = 0; - - for elem_size in self.columns.values().map(|c| c.element_size()) { - sum_so_far += elem_size?; + pub fn estimate_row_size(&self, schema: Option<&Schema>) -> super::Result { + let mut sum_so_far = 0.; + + if let Some(schema) = schema { + // if schema provided, use it + for field in schema.fields.values() { + let name = field.name.as_str(); + let elem_size = if let Some(stats) = self.columns.get(name) { + // first try to use column stats + stats.element_size()? + } else { + None + } + .or_else(|| { + // failover to use dtype estimate + field.dtype.estimate_size_bytes() + }) + .unwrap_or(0.); + sum_so_far += elem_size; + } + } else { + for elem_size in self.columns.values().map(|c| c.element_size()) { + sum_so_far += elem_size?.unwrap_or(0.); + } } + Ok(sum_so_far) }
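To make the fallback path described in the commit message concrete (no column stats and no known row count): inflate the on-disk size, approximate a row count from the read schema's per-row estimate, then charge only the materialized (projected) schema per row. A back-of-envelope sketch with made-up numbers; the helper below is illustrative and not part of Daft's API:

```python
def estimate_in_memory_size_bytes(
    file_size_on_disk: int,     # bytes on disk, e.g. one Parquet file
    inflation_factor: float,    # e.g. parquet_inflation_factor from the execution config
    read_row_size: float,       # estimated bytes per row for the full read schema
    projected_row_size: float,  # estimated bytes per row for the materialized schema
) -> int:
    in_mem_size = file_size_on_disk * inflation_factor  # decoded/decompressed size
    approx_rows = in_mem_size / read_row_size           # rough row count
    return int(approx_rows * projected_row_size)        # only projected columns materialize

# A 10 MiB Parquet file, 3x inflation, 100 B/row read schema, 24 B/row after projection:
print(estimate_in_memory_size_bytes(10 * 1024 * 1024, 3.0, 100.0, 24.0))  # ~7.5e6 bytes
```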