diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000000..9affa24854 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[env] +PYO3_PYTHON = "./.venv/bin/python" diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index ef7507bc66..c017f4cf5b 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1 +1,2 @@ d5e444d0a71409ae3701d4249ad877f1fb9e2235 # introduced `rustfmt.toml` and ran formatter; ignoring large formatting changes +45e2944e252ccdd563dc20edd9b29762e05cec1d # auto-fix prefer `Self` over explicit type diff --git a/Cargo.lock b/Cargo.lock index 592a0793ba..0dc52e53bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1324,6 +1324,7 @@ name = "common-error" version = "0.3.0-dev0" dependencies = [ "arrow2", + "parquet2", "pyo3", "regex", "serde_json", @@ -2084,6 +2085,7 @@ dependencies = [ "serde", "snafu", "test-log", + "uuid 1.10.0", ] [[package]] @@ -2170,6 +2172,7 @@ version = "0.3.0-dev0" dependencies = [ "common-daft-config", "common-error", + "common-io-config", "daft-core", "daft-dsl", "daft-functions", diff --git a/Cargo.toml b/Cargo.toml index 1d1065f026..dde3543e70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -229,7 +229,76 @@ features = ["derive", "rc"] version = "1.0.200" [workspace.lints.clippy] +as_conversions = "allow" +cast-sign-loss = "allow" +cast_lossless = "allow" +cast_possible_truncation = "allow" +cast_possible_wrap = "allow" +cast_precision_loss = "allow" +cognitive_complexity = "allow" +default_trait_access = "allow" +doc-markdown = "allow" +doc_link_with_quotes = "allow" +enum_glob_use = "allow" +float_cmp = "allow" +fn_params_excessive_bools = "allow" +from_iter_instead_of_collect = "allow" +future_not_send = "allow" +if_not_else = "allow" +implicit_hasher = "allow" +inline_always = "allow" +into_iter_without_iter = "allow" +items_after_statements = "allow" +iter_with_drain = "allow" +iter_without_into_iter = "allow" +manual_let_else = "allow" +many_single_char_names = "allow" +map_unwrap_or = "allow" +match_bool = "allow" +match_same_arms = "allow" +match_wildcard_for_single_variants = "allow" +missing-panics-doc = "allow" +missing_const_for_fn = "allow" +missing_errors_doc = "allow" +module_name_repetitions = "allow" +must_use_candidate = "allow" +needless_pass_by_value = "allow" +needless_return = "allow" +nonminimal_bool = "allow" +nursery = {level = "deny", priority = -1} +only_used_in_recursion = "allow" +option_if_let_else = "allow" +pedantic = {level = "deny", priority = -1} +perf = {level = "deny", priority = -1} +redundant_closure = "allow" +redundant_closure_for_method_calls = "allow" +redundant_else = "allow" +redundant_pub_crate = "allow" +return_self_not_must_use = "allow" +significant_drop_in_scrutinee = "allow" +significant_drop_tightening = "allow" +similar_names = "allow" +single_match = "allow" +single_match_else = "allow" +struct_excessive_bools = "allow" +style = {level = "deny", priority = 1} +suspicious_operation_groupings = "allow" +too_many_lines = "allow" +trivially_copy_pass_by_ref = "allow" +type_repetition_in_bounds = "allow" +uninlined_format_args = "allow" +unnecessary_wraps = "allow" +unnested_or_patterns = "allow" +unreadable_literal = "allow" +# todo: remove? 
+unsafe_derive_deserialize = "allow" +unused_async = "allow" +# used_underscore_items = "allow" +unused_self = "allow" use-self = "deny" +used_underscore_binding = "allow" +wildcard_imports = "allow" +zero_sized_map_values = "allow" [workspace.package] edition = "2021" diff --git a/README.rst b/README.rst index d1eb6eb42a..cea9a29283 100644 --- a/README.rst +++ b/README.rst @@ -4,13 +4,13 @@ `Website `_ • `Docs `_ • `Installation`_ • `10-minute tour of Daft `_ • `Community and Support `_ -Daft: Distributed dataframes for multimodal data -======================================================= +Daft: Unified Engine for Data Analytics, Engineering & ML/AI +============================================================ -`Daft `_ is a distributed query engine for large-scale data processing in Python and is implemented in Rust. +`Daft `_ is a distributed query engine for large-scale data processing using Python or SQL, implemented in Rust. -* **Familiar interactive API:** Lazy Python Dataframe for rapid and interactive iteration +* **Familiar interactive API:** Lazy Python Dataframe for rapid and interactive iteration, or SQL for analytical queries * **Focus on the what:** Powerful Query Optimizer that rewrites queries to be as efficient as possible * **Data Catalog integrations:** Full integration with data catalogs such as Apache Iceberg * **Rich multimodal type-system:** Supports multimodal types such as Images, URLs, Tensors and more @@ -51,7 +51,7 @@ Quickstart In this example, we load images from an AWS S3 bucket's URLs and resize each image in the dataframe: -.. code:: +.. code:: python import daft diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 9206c9da4d..47e980e907 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator from daft.plan_scheduler.physical_plan_scheduler import PartitionT from daft.runners.partitioning import PartitionCacheEntry from daft.sql.sql_connection import SQLConnection -from daft.udf import PartialStatefulUDF, PartialStatelessUDF +from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF if TYPE_CHECKING: import pyarrow as pa @@ -1051,6 +1051,7 @@ class PyExpr: def approx_count_distinct(self) -> PyExpr: ... def approx_percentiles(self, percentiles: float | list[float]) -> PyExpr: ... def mean(self) -> PyExpr: ... + def stddev(self) -> PyExpr: ... def min(self) -> PyExpr: ... def max(self) -> PyExpr: ... def any_value(self, ignore_nulls: bool) -> PyExpr: ... @@ -1134,6 +1135,7 @@ def lit(item: Any) -> PyExpr: ... def date_lit(item: int) -> PyExpr: ... def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... +def duration_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ... def series_lit(item: PySeries) -> PyExpr: ... def stateless_udf( @@ -1150,12 +1152,14 @@ def stateful_udf( expressions: list[PyExpr], return_dtype: PyDataType, resource_request: ResourceRequest | None, - init_args: tuple[tuple[Any, ...], dict[str, Any]] | None, + init_args: InitArgsType, batch_size: int | None, concurrency: int | None, ) -> PyExpr: ... def check_column_name_validity(name: str, schema: PySchema): ... -def extract_partial_stateful_udf_py(expression: PyExpr) -> dict[str, PartialStatefulUDF]: ... +def extract_partial_stateful_udf_py( + expression: PyExpr, +) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ... 
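The `(PartialStatefulUDF, InitArgsType)` pairs returned above drive stateful UDF initialization in the runners. As a minimal sketch (assuming ``partial_stateful_udfs`` holds the dict returned by ``extract_partial_stateful_udf_py``; this mirrors the ``pyrunner.py`` and ``ray_runner.py`` changes later in this diff), a runner can construct each stateful UDF class as follows:

.. code:: python

   # Sketch: build each stateful UDF instance, forwarding the optional
   # (args, kwargs) pair captured when the UDF was defined. A value of
   # None means the class is constructed with no arguments, as before.
   initialized_udfs = {}
   for name, (partial_udf, init_args) in partial_stateful_udfs.items():
       if init_args is None:
           initialized_udfs[name] = partial_udf.func_cls()
       else:
           args, kwargs = init_args
           initialized_udfs[name] = partial_udf.func_cls(*args, **kwargs)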
def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ... def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ... def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ... @@ -1199,9 +1203,17 @@ def minhash( # ----- # SQL functions # ----- +class SQLFunctionStub: + @property + def name(self) -> str: ... + @property + def docstring(self) -> str: ... + @property + def arg_names(self) -> list[str]: ... + def sql(sql: str, catalog: PyCatalog, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def sql_expr(sql: str) -> PyExpr: ... -def list_sql_functions() -> list[str]: ... +def list_sql_functions() -> list[SQLFunctionStub]: ... def utf8_count_matches(expr: PyExpr, patterns: PyExpr, whole_words: bool, case_sensitive: bool) -> PyExpr: ... def to_struct(inputs: list[PyExpr]) -> PyExpr: ... @@ -1273,6 +1285,7 @@ def dt_truncate(expr: PyExpr, interval: str, relative_to: PyExpr) -> PyExpr: ... # --- def explode(expr: PyExpr) -> PyExpr: ... def list_sort(expr: PyExpr, desc: PyExpr) -> PyExpr: ... +def list_value_counts(expr: PyExpr) -> PyExpr: ... def list_join(expr: PyExpr, delimiter: PyExpr) -> PyExpr: ... def list_count(expr: PyExpr, mode: CountMode) -> PyExpr: ... def list_get(expr: PyExpr, idx: PyExpr, default: PyExpr) -> PyExpr: ... @@ -1324,6 +1337,7 @@ class PySeries: def count(self, mode: CountMode) -> PySeries: ... def sum(self) -> PySeries: ... def mean(self) -> PySeries: ... + def stddev(self) -> PySeries: ... def min(self) -> PySeries: ... def max(self) -> PySeries: ... def agg_list(self) -> PySeries: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 6211423e94..2408890d7b 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2118,6 +2118,33 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": """ return self._apply_agg_fn(Expression.mean, cols) + @DataframePublicAPI + def stddev(self, *cols: ColumnInputType) -> "DataFrame": + """Performs a global standard deviation on the DataFrame + + Example: + >>> import daft + >>> df = daft.from_pydict({"col_a":[0,1,2]}) + >>> df = df.stddev("col_a") + >>> df.show() + ╭───────────────────╮ + │ col_a │ + │ --- │ + │ Float64 │ + ╞═══════════════════╡ + │ 0.816496580927726 │ + ╰───────────────────╯ + + (Showing first 1 of 1 rows) + + + Args: + *cols (Union[str, Expression]): columns to stddev + Returns: + DataFrame: Globally aggregated standard deviation. Should be a single row. + """ + return self._apply_agg_fn(Expression.stddev, cols) + @DataframePublicAPI def min(self, *cols: ColumnInputType) -> "DataFrame": """Performs a global min on the DataFrame @@ -2856,6 +2883,34 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": """ return self.df._apply_agg_fn(Expression.mean, cols, self.group_by) + def stddev(self, *cols: ColumnInputType) -> "DataFrame": + """Performs grouped standard deviation on this GroupedDataFrame. + + Example: + >>> import daft + >>> df = daft.from_pydict({"keys": ["a", "a", "a", "b"], "col_a": [0,1,2,100]}) + >>> df = df.groupby("keys").stddev() + >>> df.show() + ╭──────┬───────────────────╮ + │ keys ┆ col_a │ + │ --- ┆ --- │ + │ Utf8 ┆ Float64 │ + ╞══════╪═══════════════════╡ + │ a ┆ 0.816496580927726 │ + ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b ┆ 0 │ + ╰──────┴───────────────────╯ + + (Showing first 2 of 2 rows) + + Args: + *cols (Union[str, Expression]): columns to stddev + + Returns: + DataFrame: DataFrame with grouped standard deviation. 
+ """ + return self.df._apply_agg_fn(Expression.stddev, cols, self.group_by) + def min(self, *cols: ColumnInputType) -> "DataFrame": """Perform grouped min on this GroupedDataFrame. diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index eb6973f24d..56bc60bf55 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -22,12 +22,15 @@ if TYPE_CHECKING: from collections.abc import Iterator + from datetime import datetime logger = logging.getLogger(__name__) class DeltaLakeScanOperator(ScanOperator): - def __init__(self, table_uri: str, storage_config: StorageConfig) -> None: + def __init__( + self, table_uri: str, storage_config: StorageConfig, version: int | str | datetime | None = None + ) -> None: super().__init__() # Unfortunately delta-rs doesn't do very good inference of credentials for S3. Thus the current Daft behavior of passing @@ -67,6 +70,9 @@ def __init__(self, table_uri: str, storage_config: StorageConfig) -> None: table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri) ) + if version is not None: + self._table.load_as_version(version) + self._storage_config = storage_config self._schema = Schema.from_pyarrow_schema(self._table.schema().to_pyarrow()) partition_columns = set(self._table.metadata().partition_columns) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 91be85e02a..03c88b4779 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -241,6 +241,7 @@ def actor_pool_project( with get_context().runner().actor_pool_context( actor_pool_name, actor_resource_request, + task_resource_request, num_actors, projection, ) as actor_pool_id: diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 1ae7e90dac..e67ddead64 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -2,7 +2,8 @@ import math import os -from datetime import date, datetime, time +import warnings +from datetime import date, datetime, time, timedelta from decimal import Decimal from typing import ( TYPE_CHECKING, @@ -22,6 +23,7 @@ from daft.daft import col as _col from daft.daft import date_lit as _date_lit from daft.daft import decimal_lit as _decimal_lit +from daft.daft import duration_lit as _duration_lit from daft.daft import list_sort as _list_sort from daft.daft import lit as _lit from daft.daft import series_lit as _series_lit @@ -114,6 +116,12 @@ def lit(value: object) -> Expression: i64_value = pa_time.cast(pa.int64()).as_py() time_unit = TimeUnit.from_str(pa.type_for_alias(str(pa_time.type)).unit)._timeunit lit_value = _time_lit(i64_value, time_unit) + elif isinstance(value, timedelta): + # pyo3 timedelta (PyDelta) is not available when running in abi3 mode, workaround + pa_duration = pa.scalar(value) + i64_value = pa_duration.cast(pa.int64()).as_py() + time_unit = TimeUnit.from_str(pa_duration.type.unit)._timeunit + lit_value = _duration_lit(i64_value, time_unit) elif isinstance(value, Decimal): sign, digits, exponent = value.as_tuple() assert isinstance(exponent, int) @@ -854,6 +862,11 @@ def mean(self) -> Expression: expr = self._expr.mean() return Expression._from_pyexpr(expr) + def stddev(self) -> Expression: + """Calculates the standard deviation of the values in the expression""" + expr = self._expr.stddev() + return Expression._from_pyexpr(expr) + def min(self) -> Expression: """Calculates the minimum value in the expression""" expr = self._expr.min() @@ -2922,6 +2935,40 
@@ def join(self, delimiter: str | Expression) -> Expression: delimiter_expr = Expression._to_expression(delimiter) return Expression._from_pyexpr(native.list_join(self._expr, delimiter_expr._expr)) + def value_counts(self) -> Expression: + """Counts the occurrences of each unique value in the list. + + Returns: + Expression: A Map expression where the keys are unique elements from the + original list of type X, and the values are UInt64 counts representing + the number of times each element appears in the list. + + Note: + This function does not work for nested types. For example, it will not produce a map + with lists as keys. + + Example: + >>> import daft + >>> df = daft.from_pydict({"letters": [["a", "b", "a"], ["b", "c", "b", "c"]]}) + >>> df.with_column("value_counts", df["letters"].list.value_counts()).collect() + ╭──────────────┬───────────────────╮ + │ letters ┆ value_counts │ + │ --- ┆ --- │ + │ List[Utf8] ┆ Map[Utf8: UInt64] │ + ╞══════════════╪═══════════════════╡ + │ [a, b, a] ┆ [{key: a, │ + │ ┆ value: 2, │ + │ ┆ }, {key: … │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ [b, c, b, c] ┆ [{key: b, │ + │ ┆ value: 2, │ + │ ┆ }, {key: … │ + ╰──────────────┴───────────────────╯ + + (Showing first 2 of 2 rows) + """ + return Expression._from_pyexpr(native.list_value_counts(self._expr)) + def count(self, mode: CountMode = CountMode.Valid) -> Expression: """Counts the number of elements in each list @@ -2936,6 +2983,21 @@ def count(self, mode: CountMode = CountMode.Valid) -> Expression: def lengths(self) -> Expression: """Gets the length of each list + (DEPRECATED) Please use Expression.list.length instead + + Returns: + Expression: a UInt64 expression which is the length of each list + """ + warnings.warn( + "This function will be deprecated from Daft version >= 0.3.5! 
Instead, please use 'Expression.list.length'", + category=DeprecationWarning, + ) + + return Expression._from_pyexpr(native.list_count(self._expr, CountMode.All)) + + def length(self) -> Expression: + """Gets the length of each list + Returns: Expression: a UInt64 expression which is the length of each list """ @@ -3069,21 +3131,21 @@ def get(self, key: Expression) -> Expression: >>> df = daft.from_arrow(pa.table({"map_col": pa_array})) >>> df = df.with_column("a", df["map_col"].map.get("a")) >>> df.show() - ╭──────────────────────────────────────┬───────╮ - │ map_col ┆ a │ - │ --- ┆ --- │ - │ Map[Struct[key: Utf8, value: Int64]] ┆ Int64 │ - ╞══════════════════════════════════════╪═══════╡ - │ [{key: a, ┆ 1 │ - │ value: 1, ┆ │ - │ }] ┆ │ - ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ [] ┆ None │ - ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ [{key: b, ┆ None │ - │ value: 2, ┆ │ - │ }] ┆ │ - ╰──────────────────────────────────────┴───────╯ + ╭──────────────────┬───────╮ + │ map_col ┆ a │ + │ --- ┆ --- │ + │ Map[Utf8: Int64] ┆ Int64 │ + ╞══════════════════╪═══════╡ + │ [{key: a, ┆ 1 │ + │ value: 1, ┆ │ + │ }] ┆ │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ [] ┆ None │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ [{key: b, ┆ None │ + │ value: 2, ┆ │ + │ }] ┆ │ + ╰──────────────────┴───────╯ (Showing first 3 of 3 rows) diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py index 0de4c950d8..9fda932db7 100644 --- a/daft/iceberg/iceberg_write.py +++ b/daft/iceberg/iceberg_write.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Any, Iterator, List, Tuple from daft import Expression, col +from daft.datatype import DataType +from daft.io.common import _get_schema_from_dict from daft.table import MicroPartition from daft.table.partitioning import PartitionedTable, partition_strings_to_path @@ -211,7 +213,10 @@ def visitor(self, partition_record: "IcebergRecord") -> "IcebergWriteVisitors.Fi return self.FileVisitor(self, partition_record) def to_metadata(self) -> MicroPartition: - return MicroPartition.from_pydict({"data_file": self.data_files}) + col_name = "data_file" + if len(self.data_files) == 0: + return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()})) + return MicroPartition.from_pydict({col_name: self.data_files}) def partitioned_table_to_iceberg_iter( diff --git a/daft/io/_deltalake.py b/daft/io/_deltalake.py index c4530bcd98..7165c1a341 100644 --- a/daft/io/_deltalake.py +++ b/daft/io/_deltalake.py @@ -11,12 +11,15 @@ from daft.logical.builder import LogicalPlanBuilder if TYPE_CHECKING: + from datetime import datetime + from daft.unity_catalog import UnityCatalogTable @PublicAPI def read_deltalake( table: Union[str, DataCatalogTable, "UnityCatalogTable"], + version: Optional[Union[int, str, "datetime"]] = None, io_config: Optional["IOConfig"] = None, _multithreaded_io: Optional[bool] = None, ) -> DataFrame: @@ -37,8 +40,11 @@ def read_deltalake( Args: table: Either a URI for the Delta Lake table or a :class:`~daft.io.catalog.DataCatalogTable` instance referencing a table in a data catalog, such as AWS Glue Data Catalog or Databricks Unity Catalog. - io_config: A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None. - _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing + version (optional): If int is passed, read the table with specified version number. Otherwise if string or datetime, + read the timestamp version of the table. 
Strings must be RFC 3339 and ISO 8601 date and time format. + Datetimes are assumed to be UTC timezone unless specified. By default, read the latest version of the table. + io_config (optional): A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None. + _multithreaded_io (optional): Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing the amount of system resources (number of connections and thread contention) when running in the Ray runner. Defaults to None, which will let Daft decide based on the runner it is currently using. @@ -69,7 +75,7 @@ def read_deltalake( raise ValueError( f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}" ) - delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config) + delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config, version=version) handle = ScanOperatorHandle.from_python_scan_operator(delta_lake_operator) builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle) diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index e80acb03cb..d861e2e060 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -138,10 +138,15 @@ def initialize_actor_global_state(uninitialized_projection: ExpressionsProjectio logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) - # TODO: Account for Stateful Actor initialization arguments as well as user-provided batch_size - PyActorPool.initialized_stateful_udfs_process_singleton = { - name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items() - } + PyActorPool.initialized_stateful_udfs_process_singleton = {} + for name, (partial_udf, init_args) in partial_stateful_udfs.items(): + if init_args is None: + PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls() + else: + args, kwargs = init_args + PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls( + *args, **kwargs + ) @staticmethod def build_partitions_with_stateful_project( @@ -332,20 +337,27 @@ def run_iter_tables( @contextlib.contextmanager def actor_pool_context( - self, name: str, resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection + self, + name: str, + actor_resource_request: ResourceRequest, + _task_resource_request: ResourceRequest, + num_actors: int, + projection: ExpressionsProjection, ) -> Iterator[str]: actor_pool_id = f"py_actor_pool-{name}" - total_resource_request = resource_request * num_actors + total_resource_request = actor_resource_request * num_actors admitted = self._attempt_admit_task(total_resource_request) if not admitted: raise RuntimeError( - f"Not enough resources available to admit {num_actors} actors, each with resource request: {resource_request}" + f"Not enough resources available to admit {num_actors} actors, each with resource request: {actor_resource_request}" ) try: - self._actor_pools[actor_pool_id] = PyActorPool(actor_pool_id, num_actors, resource_request, projection) + self._actor_pools[actor_pool_id] = PyActorPool( + actor_pool_id, num_actors, actor_resource_request, projection + ) self._actor_pools[actor_pool_id].setup() logger.debug("Created actor pool %s with resources: %s", actor_pool_id, total_resource_request) yield actor_pool_id diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 
7c84f7b9dc..c0579e11e5 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -931,9 +931,14 @@ def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_p for name, psu in extract_partial_stateful_udf_py(expr._expr).items() } logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys())) - self.initialized_stateful_udfs = { - name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items() - } + + self.initialized_stateful_udfs = {} + for name, (partial_udf, init_args) in partial_stateful_udfs.items(): + if init_args is None: + self.initialized_stateful_udfs[name] = partial_udf.func_cls() + else: + args, kwargs = init_args + self.initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs) @ray.method(num_returns=2) def run( @@ -981,8 +986,12 @@ def __init__( self._projection = projection def setup(self) -> None: + ray_options = _get_ray_task_options(self._resource_request_per_actor) + self._actors = [ - DaftRayActor.options(name=f"rank={rank}-{self._id}").remote(self._execution_config, self._projection) # type: ignore + DaftRayActor.options(name=f"rank={rank}-{self._id}", **ray_options).remote( # type: ignore + self._execution_config, self._projection + ) for rank in range(self._num_actors) ] @@ -1150,8 +1159,16 @@ def run_iter_tables( @contextlib.contextmanager def actor_pool_context( - self, name: str, resource_request: ResourceRequest, num_actors: PartID, projection: ExpressionsProjection + self, + name: str, + actor_resource_request: ResourceRequest, + task_resource_request: ResourceRequest, + num_actors: PartID, + projection: ExpressionsProjection, ) -> Iterator[str]: + # Ray runs actor methods serially, so the resource request for an actor should be both the actor's resources and the task's resources + resource_request = actor_resource_request + task_resource_request + execution_config = get_context().daft_execution_config if self.ray_client_mode: try: diff --git a/daft/runners/runner.py b/daft/runners/runner.py index c1dd30f64e..730f5e1a4a 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -67,7 +67,8 @@ def run_iter_tables( def actor_pool_context( self, name: str, - resource_request: ResourceRequest, + actor_resource_request: ResourceRequest, + task_resource_request: ResourceRequest, num_actors: int, projection: ExpressionsProjection, ) -> Iterator[str]: diff --git a/daft/series.py b/daft/series.py index 440d570b4a..5cbcfe7ba0 100644 --- a/daft/series.py +++ b/daft/series.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import Any, Literal, TypeVar from daft.arrow_utils import ensure_array, ensure_chunked_array @@ -511,6 +512,10 @@ def mean(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.mean()) + def stddev(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.stddev()) + def sum(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.sum()) @@ -927,6 +932,14 @@ def iceberg_truncate(self, w: int) -> Series: class SeriesListNamespace(SeriesNamespace): def lengths(self) -> Series: + warnings.warn( + "This function will be deprecated from Daft version >= 0.3.5! 
Instead, please use 'length'", + category=DeprecationWarning, + ) + + return Series._from_pyseries(self._series.list_count(CountMode.All)) + + def length(self) -> Series: return Series._from_pyseries(self._series.list_count(CountMode.All)) def get(self, idx: Series, default: Series) -> Series: diff --git a/daft/sql/_sql_funcs.py b/daft/sql/_sql_funcs.py new file mode 100644 index 0000000000..030cd3b53f --- /dev/null +++ b/daft/sql/_sql_funcs.py @@ -0,0 +1,30 @@ +"""This module is used for Sphinx documentation only. We procedurally generate Python functions to allow +Sphinx to generate documentation pages for every SQL function. +""" + +from __future__ import annotations + +from inspect import Parameter as _Parameter +from inspect import Signature as _Signature + +from daft.daft import list_sql_functions as _list_sql_functions + + +def _create_sql_function(func_name: str, docstring: str, arg_names: list[str]): + def sql_function(*args, **kwargs): + raise NotImplementedError("This function is for documentation purposes only and should not be called.") + + sql_function.__name__ = func_name + sql_function.__qualname__ = func_name + sql_function.__doc__ = docstring + sql_function.__signature__ = _Signature([_Parameter(name, _Parameter.POSITIONAL_OR_KEYWORD) for name in arg_names]) # type: ignore[attr-defined] + + # Register the function in the current module + globals()[func_name] = sql_function + + +__all__ = [] + +for sql_function_stub in _list_sql_functions(): + _create_sql_function(sql_function_stub.name, sql_function_stub.docstring, sql_function_stub.arg_names) + __all__.append(sql_function_stub.name) diff --git a/daft/sql/sql.py b/daft/sql/sql.py index 987a9baeb0..2c9bb78554 100644 --- a/daft/sql/sql.py +++ b/daft/sql/sql.py @@ -1,7 +1,7 @@ # isort: dont-add-import: from __future__ import annotations import inspect -from typing import Optional, overload +from typing import Optional from daft.api_annotations import PublicAPI from daft.context import get_context @@ -38,22 +38,120 @@ def _copy_from(self, other: "SQLCatalog") -> None: @PublicAPI def sql_expr(sql: str) -> Expression: - return Expression._from_pyexpr(_sql_expr(sql)) - + """Parses a SQL string into a Daft Expression -@overload -def sql(sql: str) -> DataFrame: ... + This function allows you to create Daft Expressions from SQL snippets, which can then be used + in Daft operations or combined with other Daft Expressions. + Args: + sql (str): A SQL string to be parsed into a Daft Expression. -@overload -def sql(sql: str, catalog: SQLCatalog, register_globals: bool = ...) -> DataFrame: ... + Returns: + Expression: A Daft Expression representing the parsed SQL. 
+ + Examples: + Create a simple SQL expression: + + >>> import daft + >>> expr = daft.sql_expr("1 + 2") + >>> print(expr) + lit(1) + lit(2) + + Use SQL expression in a Daft DataFrame operation: + + >>> df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> df = df.with_column("c", daft.sql_expr("a + b")) + >>> df.show() + ╭───────┬───────┬───────╮ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Int64 │ + ╞═══════╪═══════╪═══════╡ + │ 1 ┆ 4 ┆ 5 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 5 ┆ 7 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 6 ┆ 9 │ + ╰───────┴───────┴───────╯ + + (Showing first 3 of 3 rows) + + `daft.sql_expr` is also called automatically for you in some DataFrame operations such as filters: + + >>> df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + >>> result = df.where("x < 3 AND y > 4") + >>> result.show() + ╭───────┬───────╮ + │ x ┆ y │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 2 ┆ 5 │ + ╰───────┴───────╯ + + (Showing first 1 of 1 rows) + """ + return Expression._from_pyexpr(_sql_expr(sql)) @PublicAPI def sql(sql: str, catalog: Optional[SQLCatalog] = None, register_globals: bool = True) -> DataFrame: - """Create a DataFrame from an SQL query. - - EXPERIMENTAL: This features is early in development and will change. + """Run a SQL query, returning the results as a DataFrame + + .. WARNING:: + This feature is early in development and will likely experience API changes. + + Examples: + + A simple example joining 2 dataframes together using a SQL statement, relying on Daft to detect the names of + SQL tables using their corresponding Python variable names. + + >>> import daft + >>> + >>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": ["foo", "bar", "baz"]}) + >>> df2 = daft.from_pydict({"a": [1, 2, 3], "c": ["daft", None, None]}) + >>> + >>> # Daft automatically detects `df1` and `df2` from your Python global namespace + >>> result_df = daft.sql("SELECT * FROM df1 JOIN df2 ON df1.a = df2.a") + >>> result_df.show() + ╭───────┬──────┬──────╮ + │ a ┆ b ┆ c │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Utf8 ┆ Utf8 │ + ╞═══════╪══════╪══════╡ + │ 1 ┆ foo ┆ daft │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 2 ┆ bar ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 3 ┆ baz ┆ None │ + ╰───────┴──────┴──────╯ + + (Showing first 3 of 3 rows) + + A more complex example using a SQLCatalog to create a named table called `"my_table"`, which can then be referenced from inside your SQL statement. 
+ + >>> import daft + >>> from daft.sql import SQLCatalog + >>> + >>> df = daft.from_pydict({"a": [1, 2, 3], "b": ["foo", "bar", "baz"]}) + >>> + >>> # Register dataframes as tables in SQL explicitly with names + >>> catalog = SQLCatalog({"my_table": df}) + >>> + >>> daft.sql("SELECT a FROM my_table", catalog=catalog).show() + ╭───────╮ + │ a │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 1 │ + ├╌╌╌╌╌╌╌┤ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 3 │ + ╰───────╯ + + (Showing first 3 of 3 rows) Args: sql (str): SQL query to execute diff --git a/daft/table/table_io.py b/daft/table/table_io.py index ba07fab8a4..0f892534d9 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -22,6 +22,7 @@ PythonStorageConfig, StorageConfig, ) +from daft.datatype import DataType from daft.dependencies import pa, pacsv, pads, pajson, pq from daft.expressions import ExpressionsProjection, col from daft.filesystem import ( @@ -29,6 +30,7 @@ canonicalize_protocol, get_protocol_from_path, ) +from daft.io.common import _get_schema_from_dict from daft.logical.schema import Schema from daft.runners.partitioning import ( TableParseCSVOptions, @@ -426,16 +428,22 @@ def __call__(self, written_file): self.parent.paths.append(written_file.path) self.parent.partition_indices.append(self.idx) - def __init__(self, partition_values: MicroPartition | None, path_key: str = "path"): + def __init__(self, partition_values: MicroPartition | None, schema: Schema): self.paths: list[str] = [] self.partition_indices: list[int] = [] self.partition_values = partition_values - self.path_key = path_key + self.path_key = schema.column_names()[ + 0 + ] # I kept this from our original code, but idk why it's the first column name -kevin + self.schema = schema def visitor(self, partition_idx: int) -> TabularWriteVisitors.FileVisitor: return self.FileVisitor(self, partition_idx) def to_metadata(self) -> MicroPartition: + if len(self.paths) == 0: + return MicroPartition.empty(self.schema) + metadata: dict[str, Any] = {self.path_key: self.paths} if self.partition_values: @@ -488,10 +496,7 @@ def write_tabular( partitioned = PartitionedTable(table, partition_cols) - # I kept this from our original code, but idk why it's the first column name -kevin - path_key = schema.column_names()[0] - - visitors = TabularWriteVisitors(partitioned.partition_values(), path_key) + visitors = TabularWriteVisitors(partitioned.partition_values(), schema) for i, (part_table, part_path) in enumerate(partitioned_table_to_hive_iter(partitioned, resolved_path)): size_bytes = part_table.nbytes @@ -686,7 +691,10 @@ def visitor(self, partition_values: dict[str, str | None]) -> DeltaLakeWriteVisi return self.FileVisitor(self, partition_values) def to_metadata(self) -> MicroPartition: - return MicroPartition.from_pydict({"add_action": self.add_actions}) + col_name = "add_action" + if len(self.add_actions) == 0: + return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()})) + return MicroPartition.from_pydict({col_name: self.add_actions}) def write_deltalake( diff --git a/daft/udf.py b/daft/udf.py index e2afb495ff..c662dc6ced 100644 --- a/daft/udf.py +++ b/daft/udf.py @@ -4,7 +4,7 @@ import functools import inspect from abc import abstractmethod -from typing import Any, Callable, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union from daft.context import get_context from daft.daft import PyDataType, ResourceRequest @@ -13,6 +13,7 @@ from daft.expressions import Expression from daft.series import PySeries, Series +InitArgsType = 
Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]] @@ -294,7 +295,7 @@ class StatefulUDF(UDF): name: str cls: type return_dtype: DataType - init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None + init_args: InitArgsType = None concurrency: int | None = None def __post_init__(self): diff --git a/docs/source/10-min.ipynb b/docs/source/10-min.ipynb index cbda803752..d4444c2cd5 100644 --- a/docs/source/10-min.ipynb +++ b/docs/source/10-min.ipynb @@ -569,7 +569,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "See: [Expressions](user_guide/basic_concepts/expressions.rst)\n", + "See: [Expressions](user_guide/expressions.rst)\n", "\n", "Expressions are an API for defining computation that needs to happen over your columns.\n", "\n", @@ -1516,7 +1516,7 @@ "source": [ "### User-Defined Functions\n", "\n", - "See: [UDF User Guide](user_guide/daft_in_depth/udf)" + "See: [UDF User Guide](user_guide/udf)" ] }, { diff --git a/docs/source/_static/high_level_architecture.png b/docs/source/_static/high_level_architecture.png index 8e645c8899..f5133b2736 100644 Binary files a/docs/source/_static/high_level_architecture.png and b/docs/source/_static/high_level_architecture.png differ diff --git a/docs/source/api_docs/dataframe.rst b/docs/source/api_docs/dataframe.rst index f93f052742..14a4e9fa20 100644 --- a/docs/source/api_docs/dataframe.rst +++ b/docs/source/api_docs/dataframe.rst @@ -104,6 +104,7 @@ Aggregations DataFrame.groupby DataFrame.sum DataFrame.mean + DataFrame.stddev DataFrame.count DataFrame.min DataFrame.max diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index ae34b7bb22..a53ef825fd 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -113,6 +113,7 @@ The following can be used with DataFrame.agg or GroupedDataFrame.agg Expression.count Expression.sum Expression.mean + Expression.stddev Expression.min Expression.max Expression.any_value @@ -214,7 +215,7 @@ List :template: autosummary/accessor_method.rst Expression.list.join - Expression.list.lengths + Expression.list.length Expression.list.get Expression.list.slice Expression.list.chunk diff --git a/docs/source/api_docs/index.rst b/docs/source/api_docs/index.rst index 3079870df6..6bee44ad95 100644 --- a/docs/source/api_docs/index.rst +++ b/docs/source/api_docs/index.rst @@ -7,6 +7,7 @@ API Documentation Table of Contents creation dataframe + sql expressions schema datatype diff --git a/docs/source/api_docs/sql.rst b/docs/source/api_docs/sql.rst new file mode 100644 index 0000000000..33cf0c25dd --- /dev/null +++ b/docs/source/api_docs/sql.rst @@ -0,0 +1,15 @@ +SQL +=== + +.. autofunction:: daft.sql + +.. autofunction:: daft.sql_expr + +SQL Functions +------------- + +This is a full list of functions that can be used from within SQL. + + +.. 
sql-autosummary:: + :toctree: doc_gen/sql_funcs diff --git a/docs/source/conf.py b/docs/source/conf.py index 36e66be49a..fd59d32625 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,12 +9,16 @@ import inspect import os import subprocess +import sys import sphinx_autosummary_accessors # Set environment variable to help code determine whether or not we are running a Sphinx doc build process os.environ["DAFT_SPHINX_BUILD"] = "1" +# Help Sphinx find local custom extensions/directives that we build +sys.path.insert(0, os.path.abspath("ext")) + # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "Daft" @@ -45,10 +49,15 @@ "myst_nb", "sphinx_copybutton", "sphinx_autosummary_accessors", + "sphinx_tabs.tabs", + # Local extensions + "sql_autosummary", ] templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] +# Removes module names that prefix our classes +add_module_names = False # -- Options for Notebook rendering # https://myst-nb.readthedocs.io/en/latest/configuration.html?highlight=nb_execution_mode#execution @@ -86,6 +95,13 @@ "learn/user_guides/remote_cluster_execution": "distributed-computing.html", "learn/quickstart": "learn/10-min.html", "learn/10-min": "../10-min.html", + "user_guide/basic_concepts/expressions": "user_guide/expressions", + "user_guide/basic_concepts/dataframe_introduction": "user_guide/basic_concepts", + "user_guide/basic_concepts/introduction": "user_guide/basic_concepts", + "user_guide/daft_in_depth/aggregations": "user_guide/aggregations", + "user_guide/daft_in_depth/dataframe-operations": "user_guide/dataframe-operations", + "user_guide/daft_in_depth/datatypes": "user_guide/datatypes", + "user_guide/daft_in_depth/udf": "user_guide/udf", } # Resolving code links to github diff --git a/docs/source/ext/__init__.py b/docs/source/ext/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/source/ext/sql_autosummary.py b/docs/source/ext/sql_autosummary.py new file mode 100644 index 0000000000..5e37456cbe --- /dev/null +++ b/docs/source/ext/sql_autosummary.py @@ -0,0 +1,80 @@ +import inspect +import os + +from sphinx.ext.autosummary import Autosummary +from sphinx.util import logging + +logger = logging.getLogger(__name__) + + +TOCTREE = "doc_gen/sql_funcs" +SQL_MODULE_NAME = "daft.sql._sql_funcs" + +STUB_TEMPLATE = """ +.. currentmodule:: None + +.. 
autofunction:: {module_name}.{name} +""" + + +class SQLAutosummary(Autosummary): + def run(self): + func_names = get_sql_func_names() + # Run the normal autosummary stuff, override self.content + self.content = [f"~{SQL_MODULE_NAME}.{f}" for f in func_names] + nodes = super().run() + return nodes + + def get_sql_module_name(self): + return self.arguments[0] + + +def get_sql_func_names(): + # Import the SQL functions module + module = __import__(SQL_MODULE_NAME, fromlist=[""]) + + names = [] + for name, obj in inspect.getmembers(module): + if inspect.isfunction(obj) and not name.startswith("_"): + names.append(name) + + return names + + +def generate_stub(name: str): + """Generates a stub string for a SQL function""" + stub = name + "\n" + stub += "=" * len(name) + "\n\n" + stub += STUB_TEMPLATE.format(module_name=SQL_MODULE_NAME, name=name) + return stub + + +def generate_files(app): + # Determine where to write .rst files to + output_dir = os.path.join(app.srcdir, "api_docs", TOCTREE) + os.makedirs(output_dir, exist_ok=True) + + # Write stubfiles + func_names = get_sql_func_names() + for name in func_names: + stub_content = generate_stub(name) + filename = f"{SQL_MODULE_NAME}.{name}.rst" + filepath = os.path.join(output_dir, filename) + with open(filepath, "w") as f: + f.write(stub_content) + + # HACK: Not sure if this is ok? + app.env.found_docs.add(filepath) + + +def setup(app): + app.add_directive("sql-autosummary", SQLAutosummary) + + # Generate and register files when the builder is initialized + app.connect("builder-inited", generate_files) + + return { + "version": "0.1", + "parallel_read_safe": True, + "parallel_write_safe": True, + } diff --git a/docs/source/index.rst b/docs/source/index.rst index 3a2d3eabb5..6ee5c431b7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,14 +1,53 @@ Daft Documentation ================== -Daft is a distributed query engine for large-scale data processing in Python and is implemented in Rust. - -* **Familiar interactive API:** Lazy Python Dataframe for rapid and interactive iteration -* **Focus on the what:** Powerful Query Optimizer that rewrites queries to be as efficient as possible -* **Data Catalog integrations:** Full integration with data catalogs such as Apache Iceberg -* **Rich multimodal type-system:** Supports multimodal types such as Images, URLs, Tensors and more -* **Seamless Interchange**: Built on the `Apache Arrow `_ In-Memory Format -* **Built for the cloud:** `Record-setting `_ I/O performance for integrations with S3 cloud storage +Daft is a unified data engine for **data engineering, analytics and ML/AI**. + +Daft exposes both **SQL and Python DataFrame interfaces** as first-class citizens and is written in Rust. + +Daft provides a **snappy and delightful local interactive experience**, but also seamlessly **scales to petabyte-scale distributed workloads**. 
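Both interfaces target the same query engine, so the same transformation can be written either way. A minimal sketch of this duality (the column names are illustrative; the pattern mirrors the ``daft.sql`` and ``DataFrame.where`` examples elsewhere in this diff):

.. code:: python

   import daft

   df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})

   # DataFrame API: the predicate string is parsed into an Expression
   df.where("x < 3 AND y > 4").show()

   # SQL API: `df` is picked up from the Python global namespace
   daft.sql("SELECT * FROM df WHERE x < 3 AND y > 4").show()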
+ +Use-Cases +--------- + +Data Engineering +**************** + +*Combine the performance of DuckDB, Pythonic UX of Polars and scalability of Apache Spark for data engineering from MB to PB scale* + +* Scale ETL workflows effortlessly from local to distributed environments +* Enjoy a Python-first experience without JVM dependency hell +* Leverage native integrations with cloud storage, open catalogs, and data formats + +Data Analytics +************** + +*Blend the snappiness of DuckDB with the scalability of Spark/Trino for unified local and distributed analytics* + +* Utilize complementary SQL and Python interfaces for versatile analytics +* Perform snappy local exploration with DuckDB-like performance +* Seamlessly scale to the cloud, outperforming distributed engines like Spark and Trino + +ML/AI +***** + +*Streamline ML/AI workflows with efficient dataloading from open formats like Parquet and JPEG* + +* Load data efficiently from open formats directly into PyTorch or NumPy +* Schedule large-scale model batch inference on distributed GPU clusters +* Optimize data curation with advanced clustering, deduplication, and filtering + +Technology +---------- + +Daft boasts strong integrations with technologies common across these workloads: + +* **Cloud Object Storage:** Record-setting I/O performance for integrations with S3 cloud storage, `battle-tested at exabyte-scale at Amazon `_ +* **ML/AI Python Ecosystem:** first-class integrations with `PyTorch `_ and `NumPy `_ for efficient interoperability with your ML/AI stack +* **Data Catalogs/Table Formats:** capabilities to effectively query table formats such as `Apache Iceberg `_, `Delta Lake `_ and `Apache Hudi `_ +* **Seamless Data Interchange:** zero-copy integration with `Apache Arrow `_ +* **Multimodal/ML Data:** native functionality for data modalities such as tensors, images, URLs, long-form text and embeddings + Installing Daft --------------- diff --git a/docs/source/migration_guides/coming_from_dask.rst b/docs/source/migration_guides/coming_from_dask.rst index 4e649ec8d3..8eaedb1258 100644 --- a/docs/source/migration_guides/coming_from_dask.rst +++ b/docs/source/migration_guides/coming_from_dask.rst @@ -30,7 +30,7 @@ Daft does not use an index Dask aims for as much feature-parity with pandas as possible, including maintaining the presence of an Index in the DataFrame. But keeping an Index is difficult when moving to a distributed computing environment. Dask doesn’t support row-based positional indexing (with .iloc) because it does not track the length of its partitions. It also does not support pandas MultiIndex. The argument for keeping the Index is that it makes some operations against the sorted index column very fast. In reality, resetting the Index forces a data shuffle and is an expensive operation. -Daft drops the need for an Index to make queries more readable and consistent. How you write a query should not change because of the state of an index or a reset_index call. In our opinion, eliminating the index makes things simpler, more explicit, more readable and therefore less error-prone. Daft achieves this by using the [Expressions API](../user_guide/basic_concepts/expressions). +Daft drops the need for an Index to make queries more readable and consistent. How you write a query should not change because of the state of an index or a reset_index call. In our opinion, eliminating the index makes things simpler, more explicit, more readable and therefore less error-prone. 
Daft achieves this by using the :doc:`Expressions API <../api_docs/expressions>`. In Dask you would index your DataFrame to return row ``b`` as follows: @@ -80,7 +80,7 @@ For example: res = ddf.map_partitions(my_function, **kwargs) -Daft implements two APIs for mapping computations over the data in your DataFrame in parallel: :doc:`Expressions <../user_guide/basic_concepts/expressions>` and :doc:`UDFs <../user_guide/daft_in_depth/udf>`. Expressions are most useful when you need to define computation over your columns. +Daft implements two APIs for mapping computations over the data in your DataFrame in parallel: :doc:`Expressions <../user_guide/expressions>` and :doc:`UDFs <../user_guide/udf>`. Expressions are most useful when you need to define computation over your columns. .. code:: python @@ -113,7 +113,7 @@ Daft is built as a DataFrame API for distributed Machine learning. You can use D Daft supports Multimodal Data Types ----------------------------------- -Dask supports the same data types as pandas. Daft is built to support many more data types, including Images, nested JSON, tensors, etc. See :doc:`the documentation <../user_guide/daft_in_depth/datatypes>` for a list of all supported data types. +Dask supports the same data types as pandas. Daft is built to support many more data types, including Images, nested JSON, tensors, etc. See :doc:`the documentation <../user_guide/datatypes>` for a list of all supported data types. Distributed Computing and Remote Clusters ----------------------------------------- diff --git a/docs/source/user_guide/daft_in_depth/aggregations.rst b/docs/source/user_guide/aggregations.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/aggregations.rst rename to docs/source/user_guide/aggregations.rst diff --git a/docs/source/user_guide/basic_concepts.rst b/docs/source/user_guide/basic_concepts.rst index 3bb3a89023..50fb8641cc 100644 --- a/docs/source/user_guide/basic_concepts.rst +++ b/docs/source/user_guide/basic_concepts.rst @@ -1,9 +1,407 @@ Basic Concepts ============== -.. toctree:: +Daft is a distributed data engine. The main abstraction in Daft is the :class:`DataFrame `, which conceptually can be thought of as a "table" of data with rows and columns. - basic_concepts/introduction - basic_concepts/dataframe_introduction - basic_concepts/expressions - basic_concepts/read-and-write +Daft also exposes a :doc:`sql` interface which interoperates closely with the DataFrame interface, allowing you to express data transformations and queries on your tables as SQL strings. + +.. image:: /_static/daft_illustration.png + :alt: Daft python dataframes make it easy to load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying + :width: 500 + :align: center + +Terminology +----------- + +DataFrames +^^^^^^^^^^ + +The :class:`DataFrame ` is the core concept in Daft. Think of it as a table with rows and columns, similar to a spreadsheet or a database table. It's designed to handle large amounts of data efficiently. + +Daft DataFrames are lazy. This means that calling most methods on a DataFrame will not execute that operation immediately - instead, DataFrames expose explicit methods such as :meth:`daft.DataFrame.show` and :meth:`daft.DataFrame.write_parquet` +which will actually trigger computation of the DataFrame. + +Expressions +^^^^^^^^^^^ + +An :class:`Expression ` is a fundamental concept in Daft that allows you to define computations on DataFrame columns. 
They are the building blocks for transforming and manipulating data +within your DataFrame and will be your best friend if you are working with Daft primarily using the Python API. + +Query Plan +^^^^^^^^^^ + +As mentioned earlier, Daft DataFrames are lazy. Under the hood, each DataFrame in Daft is represented by a plan of operations that describes how to compute that DataFrame. + +This plan is called the "query plan" and calling methods on the DataFrame actually adds steps to the query plan! + +When your DataFrame is executed, Daft will read this plan, optimize it to make it run faster and then execute it to compute the requested results. + +Structured Query Language (SQL) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQL is a common query language for expressing queries over tables of data. Daft exposes a SQL API as an alternative (but often also complementary API) to the Python :class:`DataFrame ` and +:class:`Expression ` APIs for building queries. + +You can use SQL in Daft via the :func:`daft.sql` function, and Daft will also convert many SQL-compatible strings into Expressions via :func:`daft.sql_expr` for easy interoperability with DataFrames. + +DataFrame +--------- + +If you are coming from other DataFrame libraries such as Pandas or Polars, here are some key differences about Daft DataFrames: + +1. **Distributed:** When running in a distributed cluster, Daft splits your data into smaller "chunks" called *Partitions*. This allows Daft to process your data in parallel across multiple machines, leveraging more resources to work with large datasets. + +2. **Lazy:** When you write operations on a DataFrame, Daft doesn't execute them immediately. Instead, it creates a plan (called a query plan) of what needs to be done. This plan is optimized and only executed when you specifically request the results, which can lead to more efficient computations. + +3. **Multimodal:** Unlike traditional tables that usually contain simple data types like numbers and text, Daft DataFrames can handle complex data types in its columns. This includes things like images, audio files, or even custom Python objects. + +Common data operations that you would perform on DataFrames are: + +1. **Filtering rows:** Use :meth:`df.where(...) ` to keep only the rows that meet certain conditions. +2. **Creating new columns:** Use :meth:`df.with_column(...) ` to add a new column based on calculations from existing ones. +3. **Joining tables:** Use :meth:`df.join(other_df, ...) ` to combine two DataFrames based on common columns. +4. **Sorting:** Use :meth:`df.sort(...) ` to arrange your data based on values in one or more columns. +5. **Grouping and aggregating:** Use :meth:`df.groupby(...).agg(...) ` to summarize your data by groups. + +Creating a Dataframe +^^^^^^^^^^^^^^^^^^^^ + +.. seealso:: + + :doc:`read-and-write` - a more in-depth guide on various options for reading/writing data to/from Daft DataFrames from in-memory data (Python, Arrow), files (Parquet, CSV, JSON), SQL Databases and Data Catalogs + +Let's create our first Dataframe from a Python dictionary of columns. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + import daft + + df = daft.from_pydict({ + "A": [1, 2, 3, 4], + "B": [1.5, 2.5, 3.5, 4.5], + "C": [True, True, False, False], + "D": [None, None, None, None], + }) + +Examine your Dataframe by printing it: + +.. code:: python + + df + +.. 
code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────┬──────╮ + │ A ┆ B ┆ C ┆ D │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Null │ + ╞═══════╪═════════╪═════════╪══════╡ + │ 1 ┆ 1.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ None │ + ╰───────┴─────────┴─────────┴──────╯ + + (Showing first 4 of 4 rows) + + +Congratulations - you just created your first DataFrame! It has 4 columns, "A", "B", "C", and "D". Let's try to select only the "A", "B", and "C" columns: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.select("A", "B", "C") + df + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT A, B, C FROM df") + df + +.. code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────╮ + │ A ┆ B ┆ C │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean │ + ╰───────┴─────────┴─────────╯ + + (No data to display: Dataframe not materialized) + + +But wait - why is it printing the message ``(No data to display: Dataframe not materialized)`` and where are the rows of each column? + +Executing our DataFrame and Viewing Data +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The reason that our DataFrame currently does not display its rows is that Daft DataFrames are **lazy**. This just means that Daft DataFrames will defer all its work until you tell it to execute. + +In this case, Daft is just deferring the work required to read the data and select columns, however in practice this laziness can be very useful for helping Daft optimize your queries before execution! + +.. NOTE:: + + When you call methods on a Daft Dataframe, it defers the work by adding to an internal "plan". You can examine the current plan of a DataFrame by calling :meth:`df.explain() `! + + Passing the ``show_all=True`` argument will show you the plan after Daft applies its query optimizations and the physical (lower-level) plan. + + .. code-block:: text + :caption: Plan Output + + == Unoptimized Logical Plan == + + * Project: col(A), col(B), col(C) + | + * Source: + | Number of partitions = 1 + | Output schema = A#Int64, B#Float64, C#Boolean, D#Null + + + == Optimized Logical Plan == + + * Project: col(A), col(B), col(C) + | + * Source: + | Number of partitions = 1 + | Output schema = A#Int64, B#Float64, C#Boolean, D#Null + + + == Physical Plan == + + * Project: col(A), col(B), col(C) + | Clustering spec = { Num partitions = 1 } + | + * InMemoryScan: + | Schema = A#Int64, B#Float64, C#Boolean, D#Null, + | Size bytes = 65, + | Clustering spec = { Num partitions = 1 } + +We can tell Daft to execute our DataFrame and store the results in-memory using :meth:`df.collect() `: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df.collect() + df + +.. 
code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────┬──────╮ + │ A ┆ B ┆ C ┆ D │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Null │ + ╞═══════╪═════════╪═════════╪══════╡ + │ 1 ┆ 1.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ None │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ None │ + ╰───────┴─────────┴─────────┴──────╯ + + (Showing first 4 of 4 rows) + +Now your DataFrame object ``df`` is **materialized** - Daft has executed all the steps required to compute the results, and has cached the results in memory so that it can display this preview. + +Any subsequent operations on ``df`` will avoid recomputations, and just use this materialized result! + +When should I materialize my DataFrame? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you "eagerly" call :meth:`df.collect() ` immediately on every DataFrame, you may run into issues: + +1. If data is too large at any step, materializing all of it may cause memory issues +2. Optimizations are not possible since we cannot "predict future operations" + +However, data science is all about experimentation and trying different things on the same data. This means that materialization is crucial when working interactively with DataFrames, since it speeds up all subsequent experimentation on that DataFrame. + +We suggest materializing DataFrames using :meth:`df.collect() ` when they contain expensive operations (e.g. sorts or expensive function calls) and have to be called multiple times by downstream code: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.sort("A") # expensive sort + df.collect() # materialize the DataFrame + + # All subsequent work on df avoids recomputing previous steps + df.sum("B").show() + df.mean("B").show() + df.with_column("try_this", df["A"] + 1).show(5) + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT * FROM df ORDER BY A") + df.collect() + + # All subsequent work on df avoids recomputing previous steps + daft.sql("SELECT sum(B) FROM df").show() + daft.sql("SELECT mean(B) FROM df").show() + daft.sql("SELECT *, (A + 1) AS try_this FROM df").show(5) + +.. code-block:: text + :caption: Output + + ╭─────────╮ + │ B │ + │ --- │ + │ Float64 │ + ╞═════════╡ + │ 12 │ + ╰─────────╯ + + (Showing first 1 of 1 rows) + + ╭─────────╮ + │ B │ + │ --- │ + │ Float64 │ + ╞═════════╡ + │ 3 │ + ╰─────────╯ + + (Showing first 1 of 1 rows) + + ╭───────┬─────────┬─────────┬──────────╮ + │ A ┆ B ┆ C ┆ try_this │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Int64 │ + ╞═══════╪═════════╪═════════╪══════════╡ + │ 1 ┆ 1.5 ┆ true ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ 3 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ 5 │ + ╰───────┴─────────┴─────────┴──────────╯ + + (Showing first 4 of 4 rows) + + +In many other cases however, there are better options than materializing your entire DataFrame with :meth:`df.collect() `: + +1. **Peeking with df.show(N)**: If you only want to "peek" at the first few rows of your data for visualization purposes, you can use :meth:`df.show(N) `, which processes and shows only the first ``N`` rows. +2. **Writing to disk**: The ``df.write_*`` methods will process and write your data to disk per-partition, avoiding materializing it all in memory at once. +3. 
**Pruning data**: You can materialize your DataFrame after performing a :meth:`df.limit() `, :meth:`df.where() ` or :meth:`df.select() ` operation, which processes your data or prunes it down to a smaller size. + +Schemas and Types +^^^^^^^^^^^^^^^^^ + +Notice also that when we printed our DataFrame, Daft displayed its **schema**. Each column of your DataFrame has a **name** and a **type**, and all data in that column will adhere to that type! + +Daft can display your DataFrame's schema without materializing it. Under the hood, it performs intelligent sampling of your data to determine the appropriate schema, and if you make any modifications to your DataFrame it can infer the resulting types based on the operation.
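+ +For example, you can peek at the schema of the (still unmaterialized) ``df`` directly - a quick sketch, assuming the :meth:`df.schema() ` accessor: + +.. code:: python + + # Prints the schema without executing the query + print(df.schema()) + +.. NOTE:: + + Under the hood, Daft represents data in the `Apache Arrow `_ format, which allows it to efficiently represent and work on data using high-performance kernels which are written in Rust. + + +Running Computation with Expressions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To run computations on data in our DataFrame, we use Expressions. + +The following statement will :meth:`df.show() ` a DataFrame that has only one column - the column ``A`` from our original DataFrame but with every row incremented by 1. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df.select(df["A"] + 1).show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + daft.sql("SELECT A + 1 FROM df").show() + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ A │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 3 │ + ├╌╌╌╌╌╌╌┤ + │ 4 │ + ├╌╌╌╌╌╌╌┤ + │ 5 │ + ╰───────╯ + + (Showing first 4 of 4 rows) + +.. NOTE:: + + A common pattern is to create new columns using ``DataFrame.with_column``: + + .. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + # Creates a new column named "foo" which takes on values + # of column "A" incremented by 1 + df = df.with_column("foo", df["A"] + 1) + df.show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + # Creates a new column named "foo" which takes on values + # of column "A" incremented by 1 + df = daft.sql("SELECT *, A + 1 AS foo FROM df") + df.show() + +.. code-block:: text + :caption: Output + + ╭───────┬─────────┬─────────┬───────╮ + │ A ┆ B ┆ C ┆ foo │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ Int64 ┆ Float64 ┆ Boolean ┆ Int64 │ + ╞═══════╪═════════╪═════════╪═══════╡ + │ 1 ┆ 1.5 ┆ true ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 2.5 ┆ true ┆ 3 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 3.5 ┆ false ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 4 ┆ 4.5 ┆ false ┆ 5 │ + ╰───────┴─────────┴─────────┴───────╯ + + (Showing first 4 of 4 rows) + +Congratulations, you have just written your first **Expression**: ``df["A"] + 1``! + +Expressions are a powerful way of describing computation on columns. For more details, check out the next section on :doc:`expressions` diff --git a/docs/source/user_guide/basic_concepts/dataframe_introduction.rst b/docs/source/user_guide/basic_concepts/dataframe_introduction.rst deleted file mode 100644 index 7e1075b34b..0000000000 --- a/docs/source/user_guide/basic_concepts/dataframe_introduction.rst +++ /dev/null @@ -1,203 +0,0 @@ -Dataframe -========= - -Data in Daft is represented as a DataFrame, which is a collection of data organized as a **table** with **rows** and **columns**. - -..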
image:: /_static/daft_illustration.png - :alt: Daft python dataframes make it easy to load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying - :width: 500 - :align: center - -This document provides an introduction to the Daft Dataframe. - -Creating a Dataframe --------------------- - -Let's create our first Dataframe from a Python dictionary of columns. - -.. code:: python - - import daft - - df = daft.from_pydict({ - "A": [1, 2, 3, 4], - "B": [1.5, 2.5, 3.5, 4.5], - "C": [True, True, False, False], - "D": [None, None, None, None], - }) - -Examine your Dataframe by printing it: - -.. code:: python - - df - -.. code:: none - - +---------+-----------+-----------+-----------+ - | A | B | C | D | - | Int64 | Float64 | Boolean | Null | - +=========+===========+===========+===========+ - | 1 | 1.5 | true | None | - +---------+-----------+-----------+-----------+ - | 2 | 2.5 | true | None | - +---------+-----------+-----------+-----------+ - | 3 | 3.5 | false | None | - +---------+-----------+-----------+-----------+ - | 4 | 4.5 | false | None | - +---------+-----------+-----------+-----------+ - (Showing first 4 of 4 rows) - - -Congratulations - you just created your first DataFrame! It has 4 columns, "A", "B", "C", and "D". Let's try to select only the "A", "B", and "C" columns: - -.. code:: python - - df.select("A", "B", "C") - -.. code:: none - - +---------+-----------+-----------+ - | A | B | C | - | Int64 | Float64 | Boolean | - +=========+===========+===========+ - +---------+-----------+-----------+ - (No data to display: Dataframe not materialized) - - -But wait - why is it printing the message ``(No data to display: Dataframe not materialized)`` and where are the rows of each column? - -Executing our DataFrame and Viewing Data ----------------------------------------- - -The reason that our DataFrame currently does not display its rows is that Daft DataFrames are **lazy**. This just means that Daft DataFrames will defer all its work until you tell it to execute. - -In this case, Daft is just deferring the work required to read the data and select columns, however in practice this laziness can be very useful for helping Daft optimize your queries before execution! - -.. NOTE:: - - When you call methods on a Daft Dataframe, it defers the work by adding to an internal "plan". You can examine the current plan of a DataFrame by calling :meth:`df.explain() `! - - Passing the ``show_all=True`` argument will show you the plan after Daft applies its query optimizations and the physical (lower-level) plan. - -We can tell Daft to execute our DataFrame and cache the results using :meth:`df.collect() `: - -.. code:: python - - df.collect() - df - -.. code:: none - - +---------+-----------+-----------+ - | A | B | C | - | Int64 | Float64 | Boolean | - +=========+===========+===========+ - | 1 | 1.5 | true | - +---------+-----------+-----------+ - | 2 | 2.5 | true | - +---------+-----------+-----------+ - | 3 | 3.5 | false | - +---------+-----------+-----------+ - | 4 | 4.5 | false | - +---------+-----------+-----------+ - (Showing first 4 of 4 rows) - -Now your DataFrame object ``df`` is **materialized** - Daft has executed all the steps required to compute the results, and has cached the results in memory so that it can display this preview. - -Any subsequent operations on ``df`` will avoid recomputations, and just use this materialized result! - -When should I materialize my DataFrame? 
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -If you "eagerly" call :meth:`df.collect() ` immediately on every DataFrame, you may run into issues: - -1. If data is too large at any step, materializing all of it may cause memory issues -2. Optimizations are not possible since we cannot "predict future operations" - -However, data science is all about experimentation and trying different things on the same data. This means that materialization is crucial when working interactively with DataFrames, since it speeds up all subsequent experimentation on that DataFrame. - -We suggest materializing DataFrames using :meth:`df.collect() ` when they contain expensive operations (e.g. sorts or expensive function calls) and have to be called multiple times by downstream code: - -.. code:: python - - df = df.with_column("A", df["A"].apply(expensive_function)) # expensive function - df = df.sort("A") # expensive sort - df.collect() # materialize the DataFrame - - # All subsequent work on df avoids recomputing previous steps - df.sum().show() - df.mean().show() - df.with_column("try_this", df["A"] + 1).show(5) - -In many other cases however, there are better options than materializing your entire DataFrame with :meth:`df.collect() `: - -1. **Peeking with df.show(N)**: If you only want to "peek" at the first few rows of your data for visualization purposes, you can use :meth:`df.show(N) `, which processes and shows only the first ``N`` rows. -2. **Writing to disk**: The ``df.write_*`` methods will process and write your data to disk per-partition, avoiding materializing it all in memory at once. -3. **Pruning data**: You can materialize your DataFrame after performing a :meth:`df.limit() `, :meth:`df.where() ` or :meth:`df.select() ` operation which processes your data or prune it down to a smaller size. - -Schemas and Types ------------------ - -Notice also that when we printed our DataFrame, Daft displayed its **schema**. Each column of your DataFrame has a **name** and a **type**, and all data in that column will adhere to that type! - -Daft can display your DataFrame's schema without materializing it. Under the hood, it performs intelligent sampling of your data to determine the appropriate schema, and if you make any modifications to your DataFrame it can infer the resulting types based on the operation. - -.. NOTE:: - - Under the hood, Daft represents data in the `Apache Arrow `_ format, which allows it to efficiently represent and work on data using high-performance kernels which are written in Rust. - - -Running Computations --------------------- - -To run computations on data in our DataFrame, we use Expressions. - -The following statement will :meth:`df.show() ` a DataFrame that has only one column - the column ``A`` from our original DataFrame but with every row incremented by 1. - -.. code:: python - - df.select(df["A"] + 1).show() - -.. code:: none - - +---------+ - | A | - | Int64 | - +=========+ - | 2 | - +---------+ - | 3 | - +---------+ - | 4 | - +---------+ - | 5 | - +---------+ - (Showing first 4 rows) - -.. NOTE:: - - A common pattern is to create a new columns using ``DataFrame.with_column``: - - .. code:: python - - # Creates a new column named "foo" which takes on values - # of column "A" incremented by 1 - df = df.with_column("foo", df["A"] + 1) - -Congratulations, you have just written your first **Expression**: ``df["A"] + 1``! - -Expressions -^^^^^^^^^^^ - -Expressions are how you define computations on your columns in Daft. 
- -The world of Daft contains much more than just numbers, and you can do much more than just add numbers together. Daft's rich Expressions API allows you to do things such as: - -1. Convert between different types with :meth:`df["numbers"].cast(float) ` -2. Download Bytes from a column containing String URLs using :meth:`df["urls"].url.download() ` -3. Run arbitrary Python functions on your data using :meth:`df["objects"].apply(my_python_function) ` - -We are also constantly looking to improve Daft and add more Expression functionality. Please contribute to the project with your ideas and code if you have an Expression in mind! - -The next section on :doc:`expressions` will provide a much deeper look at the Expressions that Daft provides. diff --git a/docs/source/user_guide/basic_concepts/expressions.rst b/docs/source/user_guide/basic_concepts/expressions.rst deleted file mode 100644 index db62ddb2fb..0000000000 --- a/docs/source/user_guide/basic_concepts/expressions.rst +++ /dev/null @@ -1,343 +0,0 @@ -Expressions -=========== - -Expressions are how you can express computations that should be run over columns of data. - -Creating Expressions --------------------- - -Referring to a column in a DataFrame -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Most commonly you will be creating expressions by referring to a column from an existing DataFrame. - -To do so, simply index a DataFrame with the string name of the column: - -.. code:: python - - import daft - - df = daft.from_pydict({"A": [1, 2, 3]}) - - # Refers to column "A" in `df` - df["A"] - -.. code:: none - - col(A) - -When we evaluate this ``df["A"]`` Expression, it will evaluate to the column from the ``df`` DataFrame with name "A"! - -Refer to a column with a certain name -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You may also find it necessary in certain situations to create an Expression with just the name of a column, without having an existing DataFrame to refer to. You can do this with the :func:`~daft.expressions.col` helper: - -.. code:: python - - from daft import col - - # Refers to a column named "A" - col("A") - -When this Expression is evaluated, it will resolve to "the column named A" in whatever evaluation context it is used within! - -Refer to multiple columns using a wildcard -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can create expressions on multiple columns at once using a wildcard. The expression `col("*")` selects every column in a DataFrame, and you can operate on this expression in the same way as a single column: - -.. code:: python - - import daft - from daft import col - - df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) - df.select(col("*") * 3).show() - -.. code:: none - - ╭───────┬───────╮ - │ A ┆ B │ - │ --- ┆ --- │ - │ Int64 ┆ Int64 │ - ╞═══════╪═══════╡ - │ 3 ┆ 12 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ 6 ┆ 15 │ - ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ - │ 9 ┆ 18 │ - ╰───────┴───────╯ - -Literals -^^^^^^^^ - -You may find yourself needing to hardcode a "single value" oftentimes as an expression. Daft provides a :func:`~daft.expressions.lit` helper to do so: - -.. code:: python - - from daft import lit - - # Refers to an expression which always evaluates to 42 - lit(42) - -This special :func:`~daft.expressions.lit` expression we just created evaluates always to the value ``42``. - -.. _userguide-numeric-expressions: - -Numeric Expressions -------------------- - -Since column "A" is an integer, we can run numeric computation such as addition, division and checking its value. 
Here are some examples where we create new columns using the results of such computations: - -.. code:: python - - # Add 1 to each element in column "A" - df = df.with_column("A_add_one", df["A"] + 1) - - # Divide each element in column A by 2 - df = df.with_column("A_divide_two", df["A"] / 2.) - - # Check if each element in column A is more than 1 - df = df.with_column("A_gt_1", df["A"] > 1) - - df.collect() - -.. code:: none - - +---------+-------------+----------------+-----------+ - | A | A_add_one | A_divide_two | A_gt_1 | - | Int64 | Int64 | Float64 | Boolean | - +=========+=============+================+===========+ - | 1 | 2 | 0.5 | false | - +---------+-------------+----------------+-----------+ - | 2 | 3 | 1 | true | - +---------+-------------+----------------+-----------+ - | 3 | 4 | 1.5 | true | - +---------+-------------+----------------+-----------+ - (Showing first 3 of 3 rows) - -Notice that the returned types of these operations are also well-typed according to their input types. For example, calling ``df["A"] > 1`` returns a column of type :meth:`Boolean `. - -Both the :meth:`Float ` and :meth:`Int ` types are numeric types, and inherit many of the same arithmetic Expression operations. You may find the full list of numeric operations in the :ref:`Expressions API reference `. - -.. _userguide-string-expressions: - -String Expressions ------------------- - -Daft also lets you have columns of strings in a DataFrame. Let's take a look! - -.. code:: python - - df = daft.from_pydict({"B": ["foo", "bar", "baz"]}) - df.show() - -.. code:: none - - +--------+ - | B | - | Utf8 | - +========+ - | foo | - +--------+ - | bar | - +--------+ - | baz | - +--------+ - (Showing first 3 rows) - -Unlike the numeric types, the string type does not support arithmetic operations such as ``*`` and ``/``. The one exception to this is the ``+`` operator, which is overridden to concatenate two string expressions as is commonly done in Python. Let's try that! - -.. code:: python - - df = df.with_column("B2", df["B"] + "foo") - df.show() - -.. code:: none - - +--------+--------+ - | B | B2 | - | Utf8 | Utf8 | - +========+========+ - | foo | foofoo | - +--------+--------+ - | bar | barfoo | - +--------+--------+ - | baz | bazfoo | - +--------+--------+ - (Showing first 3 rows) - -There are also many string operators that are accessed through a separate :meth:`.str.* ` "method namespace". - -For example, to check if each element in column "B" contains the substring "a", we can use the :meth:`.str.contains ` method: - -.. code:: python - - df = df.with_column("B2_contains_B", df["B2"].str.contains(df["B"])) - df.show() - -.. code:: none - - +--------+--------+-----------------+ - | B | B2 | B2_contains_B | - | Utf8 | Utf8 | Boolean | - +========+========+=================+ - | foo | foofoo | true | - +--------+--------+-----------------+ - | bar | barfoo | true | - +--------+--------+-----------------+ - | baz | bazfoo | true | - +--------+--------+-----------------+ - (Showing first 3 rows) - -You may find a full list of string operations in the :ref:`Expressions API reference `. - -URL Expressions -^^^^^^^^^^^^^^^ - -One special case of a String column you may find yourself working with is a column of URL strings. - -Daft provides the :meth:`.url.* ` method namespace with functionality for working with URL strings. For example, to download data from URLs: - -.. 
code:: python - - df = daft.from_pydict({ - "urls": [ - "https://www.google.com", - "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", - ], - }) - df = df.with_column("data", df["urls"].url.download()) - df.collect() - -.. code:: none - - +----------------------+----------------------+ - | urls | data | - | Utf8 | Binary | - +======================+======================+ - | https://www.google.c | b'`_ as the underlying executor, so you can find the full list of supported filters in the `jaq documentation `_. - -.. _userguide-logical-expressions: - -Logical Expressions -------------------- - -Logical Expressions are an expression that refers to a column of type :meth:`Boolean `, and can only take on the values True or False. - -.. code:: python - - df = daft.from_pydict({"C": [True, False, True]}) - df["C"] - -Daft supports logical operations such as ``&`` (and) and ``|`` (or) between logical expressions. - -Comparisons -^^^^^^^^^^^ - -Many of the types in Daft support comparisons between expressions that returns a Logical Expression. - -For example, here we can compare if each element in column "A" is equal to elements in column "B": - -.. code:: python - - df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 4]}) - - df = df.with_column("A_eq_B", df["A"] == df["B"]) - - df.collect() - -.. code:: none - - +---------+---------+-----------+ - | A | B | A_eq_B | - | Int64 | Int64 | Boolean | - +=========+=========+===========+ - | 1 | 1 | true | - +---------+---------+-----------+ - | 2 | 2 | true | - +---------+---------+-----------+ - | 3 | 4 | false | - +---------+---------+-----------+ - (Showing first 3 of 3 rows) - -Other useful comparisons can be found in the :ref:`Expressions API reference `. - -If Else Pattern -^^^^^^^^^^^^^^^ - -The :meth:`.if_else() ` method is a useful expression to have up your sleeve for choosing values between two other expressions based on a logical expression: - -.. code:: python - - df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) - - # Pick values from column A if the value in column A is bigger - # than the value in column B. Otherwise, pick values from column B. - df = df.with_column( - "A_if_bigger_else_B", - (df["A"] > df["B"]).if_else(df["A"], df["B"]), - ) - - df.collect() - -.. code:: none - - +---------+---------+----------------------+ - | A | B | A_if_bigger_else_B | - | Int64 | Int64 | Int64 | - +=========+=========+======================+ - | 1 | 0 | 1 | - +---------+---------+----------------------+ - | 2 | 2 | 2 | - +---------+---------+----------------------+ - | 3 | 4 | 4 | - +---------+---------+----------------------+ - (Showing first 3 of 3 rows) - -This is a useful expression for cleaning your data! diff --git a/docs/source/user_guide/basic_concepts/introduction.rst b/docs/source/user_guide/basic_concepts/introduction.rst deleted file mode 100644 index 2fa1c8fa94..0000000000 --- a/docs/source/user_guide/basic_concepts/introduction.rst +++ /dev/null @@ -1,92 +0,0 @@ -Introduction -============ - -Daft is a distributed query engine with a DataFrame API. The two key concepts to Daft are: - -1. :class:`DataFrame `: a Table-like structure that represents rows and columns of data -2. :class:`Expression `: a symbolic representation of computation that transforms columns of the DataFrame to a new one. - -With Daft, you create :class:`DataFrame ` from a variety of sources (e.g. reading data from files, data catalogs or from Python dictionaries) and use :class:`Expression ` to manipulate data in that DataFrame. 
Let's take a closer look at these two abstractions! - -DataFrame ---------- - -Conceptually, a DataFrame is a "table" of data, with rows and columns. - -.. image:: /_static/daft_illustration.png - :alt: Daft python dataframes make it easy to load any data such as PDF documents, images, protobufs, csv, parquet and audio files into a table dataframe structure for easy querying - :width: 500 - :align: center - -Using this abstraction of a DataFrame, you can run common tabular operations such as: - -1. Filtering rows: :meth:`df.where(...) ` -2. Creating new columns as a computation of existing columns: :meth:`df.with_column(...) ` -3. Joining two tables together: :meth:`df.join(...) ` -4. Sorting a table by the values in specified column(s): :meth:`df.sort(...) ` -5. Grouping and aggregations: :meth:`df.groupby(...).agg(...) ` - -Daft DataFrames are: - -1. **Distributed:** your data is split into *Partitions* and can be processed in parallel/on different machines -2. **Lazy:** computations are enqueued in a query plan which is then optimized and executed only when requested -3. **Multimodal:** columns can contain complex datatypes such as tensors, images and Python objects - -Since Daft is lazy, it can actually execute the query plan on a variety of different backends. By default, it will run computations locally using Python multithreading. However if you need to scale to large amounts of data that cannot be processed on a single machine, using the Ray runner allows Daft to run computations on a `Ray `_ cluster instead. - -Expressions ------------ - -The other important concept to understand when working with Daft are **expressions**. - -Because Daft is "lazy", it needs a way to represent computations that need to be performed on its data so that it can execute these computations at some later time. The answer to this is an :class:`~daft.expressions.Expression`! - -The simplest Expressions are: - -1. The column expression: :func:`col("a") ` which is used to refer to "some column named 'a'" -2. Or, if you already have an existing DataFrame ``df`` with a column named "a", you can refer to its column with Python's square bracket indexing syntax: ``df["a"]`` -3. The literal expression: :func:`lit(100) ` which represents a column that always takes on the provided value - -Daft then provides an extremely rich Expressions library to allow you to compose different computations that need to happen. For example: - -.. code:: python - - from daft import col, DataType - - # Take the column named "a" and add 1 to each element - col("a") + 1 - - # Take the column named "a", cast it to a string and check each element, returning True if it starts with "1" - col("a").cast(DataType.string()).str.startswith("1") - -Expressions are used in DataFrame operations, and the names of these Expressions are resolved to column names on the DataFrame that they are running on. Here is an example: - -.. code:: python - - import daft - - # Create a dataframe with a column "a" that has values [1, 2, 3] - df = daft.from_pydict({"a": [1, 2, 3]}) - - # Create new columns called "a_plus_1" and "a_startswith_1" using Expressions - df = df.select( - col("a"), - (col("a") + 1).alias("a_plus_1"), - col("a").cast(DataType.string()).str.startswith("1").alias("a_startswith_1"), - ) - - df.show() - -.. 
code:: none - - +---------+------------+------------------+ - | a | a_plus_1 | a_startswith_1 | - | Int64 | Int64 | Boolean | - +=========+============+==================+ - | 1 | 2 | true | - +---------+------------+------------------+ - | 2 | 3 | false | - +---------+------------+------------------+ - | 3 | 4 | false | - +---------+------------+------------------+ - (Showing first 3 rows) diff --git a/docs/source/user_guide/daft_in_depth.rst b/docs/source/user_guide/daft_in_depth.rst deleted file mode 100644 index 9b9702daca..0000000000 --- a/docs/source/user_guide/daft_in_depth.rst +++ /dev/null @@ -1,9 +0,0 @@ -Daft in Depth -============= - -.. toctree:: - - daft_in_depth/datatypes - daft_in_depth/dataframe-operations - daft_in_depth/aggregations - daft_in_depth/udf diff --git a/docs/source/user_guide/daft_in_depth/dataframe-operations.rst b/docs/source/user_guide/dataframe-operations.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/dataframe-operations.rst rename to docs/source/user_guide/dataframe-operations.rst diff --git a/docs/source/user_guide/daft_in_depth/datatypes.rst b/docs/source/user_guide/datatypes.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/datatypes.rst rename to docs/source/user_guide/datatypes.rst diff --git a/docs/source/user_guide/expressions.rst b/docs/source/user_guide/expressions.rst new file mode 100644 index 0000000000..54147a9401 --- /dev/null +++ b/docs/source/user_guide/expressions.rst @@ -0,0 +1,584 @@ +Expressions +=========== + +Expressions are how you can express computations that should be run over columns of data. + +Creating Expressions +^^^^^^^^^^^^^^^^^^^^ + +Referring to a column in a DataFrame +#################################### + +Most commonly you will be creating expressions by using the :func:`daft.col` function. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + # Refers to column "A" + daft.col("A") + + .. group-tab:: ⚙️ SQL + + .. code:: python + + daft.sql_expr("A") + +.. code-block:: text + :caption: Output + + col(A) + +The above code creates an Expression that refers to a column named ``"A"``. + +Using SQL +######### + +Daft can also parse valid SQL as expressions. + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + daft.sql_expr("A + 1") + +.. code-block:: text + :caption: Output + + col(A) + lit(1) + +The above code will create an expression representing "the column named 'A' incremented by 1". For many APIs, sql_expr will actually be applied for you as syntactic sugar! + +Literals +######## + +You may often find yourself needing to hardcode a "single value" as an expression. Daft provides a :func:`~daft.expressions.lit` helper to do so: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + from daft import lit + + # Refers to an expression which always evaluates to 42 + lit(42) + + .. group-tab:: ⚙️ SQL + + .. code:: python + + # Refers to an expression which always evaluates to 42 + daft.sql_expr("42") + +.. code-block:: text + :caption: Output + + lit(42) + +This special :func:`~daft.expressions.lit` expression we just created always evaluates to the value ``42``.
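+ +Literals shine when combined with other expressions - for instance (a small sketch, reusing a DataFrame ``df`` that has an integer column "A"): + +.. code:: python + + from daft import col, lit + + # Add the literal 42 to every element of column "A" + df.select(col("A") + lit(42)) + +Wildcard Expressions +#################### + +You can create expressions on multiple columns at once using a wildcard. The expression ``col("*")`` selects every column in a DataFrame, and you can operate on this expression in the same way as a single column: + +.. tabs:: + + .. group-tab:: 🐍 Python + + ..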
code:: python + + import daft + from daft import col + + df = daft.from_pydict({"A": [1, 2, 3], "B": [4, 5, 6]}) + df.select(col("*") * 3).show() + +.. code-block:: text + :caption: Output + + ╭───────┬───────╮ + │ A ┆ B │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 3 ┆ 12 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 6 ┆ 15 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 9 ┆ 18 │ + ╰───────┴───────╯ + +Wildcards also work very well for accessing all members of a struct column: + + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + import daft + from daft import col + + df = daft.from_pydict({ + "person": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35} + ] + }) + + # Access all fields of the 'person' struct + df.select(col("person.*")).show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + import daft + + df = daft.from_pydict({ + "person": [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35} + ] + }) + + # Access all fields of the 'person' struct using SQL + daft.sql("SELECT person.* FROM df").show() + +.. code-block:: text + :caption: Output + + ╭──────────┬───────╮ + │ name ┆ age │ + │ --- ┆ --- │ + │ String ┆ Int64 │ + ╞══════════╪═══════╡ + │ Alice ┆ 30 │ + ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ Bob ┆ 25 │ + ├╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ Charlie ┆ 35 │ + ╰──────────┴───────╯ + +In this example, we use the wildcard `*` to access all fields of the `person` struct column. This is equivalent to selecting each field individually (`person.name`, `person.age`), but is more concise and flexible, especially when dealing with structs that have many fields. + + + +Composing Expressions +^^^^^^^^^^^^^^^^^^^^^ + +.. _userguide-numeric-expressions: + +Numeric Expressions +################### + +Since column "A" is an integer, we can run numeric computation such as addition, division and checking its value. Here are some examples where we create new columns using the results of such computations: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + # Add 1 to each element in column "A" + df = df.with_column("A_add_one", df["A"] + 1) + + # Divide each element in column A by 2 + df = df.with_column("A_divide_two", df["A"] / 2.) + + # Check if each element in column A is more than 1 + df = df.with_column("A_gt_1", df["A"] > 1) + + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql(""" + SELECT + *, + A + 1 AS A_add_one, + A / 2.0 AS A_divide_two, + A > 1 AS A_gt_1 + FROM df + """) + df.collect() + +.. code-block:: text + :caption: Output + + +---------+-------------+----------------+-----------+ + | A | A_add_one | A_divide_two | A_gt_1 | + | Int64 | Int64 | Float64 | Boolean | + +=========+=============+================+===========+ + | 1 | 2 | 0.5 | false | + +---------+-------------+----------------+-----------+ + | 2 | 3 | 1 | true | + +---------+-------------+----------------+-----------+ + | 3 | 4 | 1.5 | true | + +---------+-------------+----------------+-----------+ + (Showing first 3 of 3 rows) + +Notice that the returned types of these operations are also well-typed according to their input types. For example, calling ``df["A"] > 1`` returns a column of type :meth:`Boolean `. + +Both the :meth:`Float ` and :meth:`Int ` types are numeric types, and inherit many of the same arithmetic Expression operations. You may find the full list of numeric operations in the :ref:`Expressions API reference `. + +.. 
_userguide-string-expressions: + +String Expressions +################## + +Daft also lets you have columns of strings in a DataFrame. Let's take a look! + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"B": ["foo", "bar", "baz"]}) + df.show() + +.. code-block:: text + :caption: Output + + +--------+ + | B | + | Utf8 | + +========+ + | foo | + +--------+ + | bar | + +--------+ + | baz | + +--------+ + (Showing first 3 rows) + +Unlike the numeric types, the string type does not support arithmetic operations such as ``*`` and ``/``. The one exception to this is the ``+`` operator, which is overridden to concatenate two string expressions as is commonly done in Python. Let's try that! + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.with_column("B2", df["B"] + "foo") + df.show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT *, B + 'foo' AS B2 FROM df") + df.show() + +.. code-block:: text + :caption: Output + + +--------+--------+ + | B | B2 | + | Utf8 | Utf8 | + +========+========+ + | foo | foofoo | + +--------+--------+ + | bar | barfoo | + +--------+--------+ + | baz | bazfoo | + +--------+--------+ + (Showing first 3 rows) + +There are also many string operators that are accessed through a separate :meth:`.str.* ` "method namespace". + +For example, to check if each element in column "B" contains the substring "a", we can use the :meth:`.str.contains ` method: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = df.with_column("B2_contains_B", df["B2"].str.contains(df["B"])) + df.show() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.sql("SELECT *, contains(B2, B) AS B2_contains_B FROM df") + df.show() + +.. code-block:: text + :caption: Output + + +--------+--------+-----------------+ + | B | B2 | B2_contains_B | + | Utf8 | Utf8 | Boolean | + +========+========+=================+ + | foo | foofoo | true | + +--------+--------+-----------------+ + | bar | barfoo | true | + +--------+--------+-----------------+ + | baz | bazfoo | true | + +--------+--------+-----------------+ + (Showing first 3 rows) + +You may find a full list of string operations in the :ref:`Expressions API reference `. + +URL Expressions +############### + +One special case of a String column you may find yourself working with is a column of URL strings. + +Daft provides the :meth:`.url.* ` method namespace with functionality for working with URL strings. For example, to download data from URLs: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({ + "urls": [ + "https://www.google.com", + "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", + ], + }) + df = df.with_column("data", df["urls"].url.download()) + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + + df = daft.from_pydict({ + "urls": [ + "https://www.google.com", + "s3://daft-public-data/open-images/validation-images/0001eeaf4aed83f9.jpg", + ], + }) + df = daft.sql(""" + SELECT + urls, + url_download(urls) AS data + FROM df + """) + df.collect() + +.. code-block:: text + :caption: Output + + +----------------------+----------------------+ + | urls | data | + | Utf8 | Binary | + +======================+======================+ + | https://www.google.c | b'...' | + +----------------------+----------------------+ + | s3://daft-public-dat | b'...' | + +----------------------+----------------------+ + + (Showing first 2 of 2 rows) + +JSON Expressions +################ + +Daft also provides a method namespace for working with columns of JSON strings. Daft uses `jaq `_ as the underlying executor, so you can find the full list of supported filters in the `jaq documentation `_.
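+ +As a quick sketch of what a JSON filter looks like (this example assumes the ``.json.query(...)`` method, which takes a jq-style filter string): + +.. code:: python + + df = daft.from_pydict({"json": ['{"a": 1}', '{"a": 2}']}) + + # Extract the value at key "a" from each JSON object + df = df.with_column("a", df["json"].json.query(".a")) + df.collect() + +..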
_userguide-logical-expressions: + +Logical Expressions +################### + +A Logical Expression is an expression that refers to a column of type :meth:`Boolean `, and can only take on the values True or False. + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"C": [True, False, True]}) + +Daft supports logical operations such as ``&`` (and) and ``|`` (or) between logical expressions.
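+ +For example (a small sketch - the second column ``D`` here is made up purely for illustration): + +.. code:: python + + df = daft.from_pydict({"C": [True, False, True], "D": [True, True, False]}) + + # True only where both "C" and "D" are True + df = df.with_column("C_and_D", df["C"] & df["D"]) + df.collect() + +Comparisons +########### + +Many of the types in Daft support comparisons between expressions that return a Logical Expression. + +For example, here we can compare if each element in column "A" is equal to elements in column "B": + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 4]}) + + df = df.with_column("A_eq_B", df["A"] == df["B"]) + + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 4]}) + + df = daft.sql(""" + SELECT + A, + B, + A = B AS A_eq_B + FROM df + """) + + df.collect() + +.. code-block:: text + :caption: Output + + ╭───────┬───────┬─────────╮ + │ A ┆ B ┆ A_eq_B │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Boolean │ + ╞═══════╪═══════╪═════════╡ + │ 1 ┆ 1 ┆ true │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 2 ┆ true │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 4 ┆ false │ + ╰───────┴───────┴─────────╯ + + (Showing first 3 of 3 rows) + +Other useful comparisons can be found in the :ref:`Expressions API reference `. + +If Else Pattern +############### + +The :meth:`.if_else() ` method is a useful expression to have up your sleeve for choosing values between two other expressions based on a logical expression: + +.. tabs:: + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) + + # Pick values from column A if the value in column A is bigger + # than the value in column B. Otherwise, pick values from column B. + df = df.with_column( + "A_if_bigger_else_B", + (df["A"] > df["B"]).if_else(df["A"], df["B"]), + ) + + df.collect() + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [0, 2, 4]}) + + df = daft.sql(""" + SELECT + A, + B, + CASE + WHEN A > B THEN A + ELSE B + END AS A_if_bigger_else_B + FROM df + """) + + df.collect() + +.. code-block:: text + :caption: Output + + ╭───────┬───────┬────────────────────╮ + │ A ┆ B ┆ A_if_bigger_else_B │ + │ --- ┆ --- ┆ --- │ + │ Int64 ┆ Int64 ┆ Int64 │ + ╞═══════╪═══════╪════════════════════╡ + │ 1 ┆ 0 ┆ 1 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 2 ┆ 2 ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ 3 ┆ 4 ┆ 4 │ + ╰───────┴───────┴────────────────────╯ + + (Showing first 3 of 3 rows) + +This is a useful expression for cleaning your data! diff --git a/docs/source/user_guide/fotw/fotw-001-images.ipynb b/docs/source/user_guide/fotw/fotw-001-images.ipynb index 827f98dd57..37d1f796d2 100644 --- a/docs/source/user_guide/fotw/fotw-001-images.ipynb +++ b/docs/source/user_guide/fotw/fotw-001-images.ipynb @@ -447,7 +447,7 @@ "metadata": {}, "source": [ "### Create Thumbnails\n", - "[Expressions](../basic_concepts/expressions) are a Daft API for defining computation that needs to happen over your columns. There are dedicated `image.(...)` Expressions for working with images.\n", + "[Expressions](../expressions) are a Daft API for defining computation that needs to happen over your columns. 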
There are dedicated `image.(...)` Expressions for working with images.\n", "\n", "You can use the `image.resize` Expression to create a thumbnail of each image:" ] @@ -527,7 +527,7 @@ "\n", "We'll define a function that uses a pre-trained PyTorch model [ResNet50](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html) to classify the dog pictures. We'll then pass the `image` column to this PyTorch model and send the classification predictions to a new column `classify_breed`. \n", - "You will use Daft [User-Defined Functions (UDFs)](../daft_in_depth/udf) to do this. Daft UDFs which are the best way to run computations over multiple rows or columns.\n", + "You will use Daft [User-Defined Functions (UDFs)](../udf) to do this. Daft UDFs are the best way to run computations over multiple rows or columns.\n", "\n", "#### Setting up PyTorch\n", "\n", diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst index e79607a84d..b4b7150215 100644 --- a/docs/source/user_guide/index.rst +++ b/docs/source/user_guide/index.rst @@ -6,7 +6,13 @@ Daft User Guide :maxdepth: 1 basic_concepts - daft_in_depth + read-and-write + expressions + datatypes + dataframe-operations + sql + aggregations + udf poweruser integrations tutorials @@ -14,22 +20,7 @@ Daft User Guide Welcome to **Daft**! -Daft is a Python dataframe library that enables Pythonic data processing at large scale. - -* **Fast** - Daft kernels are written and accelerated using Rust on Apache Arrow arrays. - -* **Flexible** - you can work with any Python object in a Daft Dataframe. - -* **Interactive** - Daft provides a first-class notebook experience. - -* **Scalable** - Daft uses out-of-core algorithms to work with datasets that cannot fit in memory. - -* **Distributed** - Daft scales to a cluster of machines using Ray to crunch terabytes of data. - -* **Intelligent** - Daft performs query optimizations to speed up your work. - -The core interface provided by Daft is the *DataFrame*, which is a table of data consisting of rows and columns. This user guide -aims to help Daft users master the usage of the Daft *DataFrame* for all your data processing needs! +This user guide aims to help Daft users master the usage of Daft for all your data needs. .. NOTE:: @@ -39,8 +30,7 @@ aims to help Daft users master the usage of the Daft *DataFrame* for all your da code you may wish to take a look at these resources: 1. :doc:`../10-min`: Itching to run some Daft code? Hit the ground running with our 10 minute quickstart notebook. - 2. (Coming soon!) Cheatsheet: Quick reference to commonly-used Daft APIs and usage patterns - useful to keep next to your laptop as you code! - 3. :doc:`../api_docs/index`: Searchable documentation and reference material to Daft's public Python API. + 2. :doc:`../api_docs/index`: Searchable documentation and reference material to Daft's public API. Table of Contents ----------------- @@ -52,11 +42,23 @@ The Daft User Guide is laid out as follows: High-level overview of Daft interfaces and usage to give you a better understanding of how Daft will fit into your day-to-day workflow. -:doc:`Daft in Depth ` -************************************ +Daft in Depth +************* Core Daft concepts all Daft users will find useful to understand deeply. 
+* :doc:`read-and-write` +* :doc:`expressions` +* :doc:`datatypes` +* :doc:`dataframe-operations` +* :doc:`aggregations` +* :doc:`udf` + +:doc:`Structured Query Language (SQL) ` +******************************************** + +A look into Daft's SQL interface and how it complements Daft's Pythonic DataFrame APIs. + :doc:`The Daft Poweruser ` ************************************* diff --git a/docs/source/user_guide/basic_concepts/read-and-write.rst b/docs/source/user_guide/read-and-write.rst similarity index 64% rename from docs/source/user_guide/basic_concepts/read-and-write.rst rename to docs/source/user_guide/read-and-write.rst index 1d1a481fea..f8585111d9 100644 --- a/docs/source/user_guide/basic_concepts/read-and-write.rst +++ b/docs/source/user_guide/read-and-write.rst @@ -1,5 +1,5 @@ -Reading/Writing -=============== +Reading/Writing Data +==================== Daft can read data from a variety of sources, and write data to many destinations. @@ -37,7 +37,7 @@ To learn more about each of these constructors, as well as the options that they From Data Catalogs ^^^^^^^^^^^^^^^^^^ -If you use catalogs such as Apache Iceberg or Hive, you may wish to consult our user guide on integrations with Data Catalogs: :doc:`Daft integration with Data Catalogs <../integrations/>`. +If you use catalogs such as Apache Iceberg or Hive, you may wish to consult our user guide on integrations with Data Catalogs: :doc:`Daft integration with Data Catalogs `. From File Paths ^^^^^^^^^^^^^^^ @@ -87,7 +87,50 @@ In order to partition the data, you can specify a partition column, which will a # Read with a partition column df = daft.read_sql("SELECT * FROM my_table", uri, partition_col="date") -To learn more, consult the :doc:`SQL User Guide <../integrations/sql>` or the API documentation on :func:`daft.read_sql`. +To learn more, consult the :doc:`SQL User Guide ` or the API documentation on :func:`daft.read_sql`. + + +Reading a column of URLs +------------------------ + +Daft provides a convenient way to read data from a column of URLs using the :meth:`.url.download() ` method. This is particularly useful when you have a DataFrame with a column containing URLs pointing to external resources that you want to fetch and incorporate into your DataFrame. + +Here's an example of how to use this feature: + +.. code:: python + + # Assume we have a DataFrame with a column named 'image_urls' + df = daft.from_pydict({ + "image_urls": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + "https://example.com/image3.jpg" + ] + }) + + # Download the content from the URLs and create a new column 'image_data' + df = df.with_column("image_data", df["image_urls"].url.download()) + df.show() + +.. code-block:: text + :caption: Output + + +------------------------------------+------------------------------------+ + | image_urls | image_data | + | Utf8 | Binary | + +====================================+====================================+ + | https://example.com/image1.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' | + +------------------------------------+------------------------------------+ + | https://example.com/image2.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' | + +------------------------------------+------------------------------------+ + | https://example.com/image3.jpg | b'\xff\xd8\xff\xe0\x00\x10JFIF...' 
| + +------------------------------------+------------------------------------+ + + (Showing first 3 of 3 rows) + + +This approach allows you to efficiently download and process data from a large number of URLs in parallel, leveraging Daft's distributed computing capabilities. + Writing Data diff --git a/docs/source/user_guide/sql.rst b/docs/source/user_guide/sql.rst new file mode 100644 index 0000000000..fec2761e05 --- /dev/null +++ b/docs/source/user_guide/sql.rst @@ -0,0 +1,244 @@ +SQL +=== + +Daft supports Structured Query Language (SQL) as a way of constructing query plans (represented in Python as a :class:`daft.DataFrame`) and expressions (:class:`daft.Expression`). + +SQL is a human-readable way of constructing these query plans, and can often be more ergonomic than using DataFrames for writing queries. + +.. NOTE:: + Daft's SQL support is new and is constantly being improved on! Please give us feedback and we'd love to hear more about what you would like. + +Running SQL on DataFrames +------------------------- + +Daft's :func:`daft.sql` function will automatically detect any :class:`daft.DataFrame` objects in your current Python environment to let you query them easily by name. + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + # Note the variable name `my_special_df` + my_special_df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + + # Use the SQL table name "my_special_df" to refer to the above DataFrame! + sql_df = daft.sql("SELECT A, B FROM my_special_df") + + sql_df.show() + +.. code-block:: text + :caption: Output + + ╭───────┬───────╮ + │ A ┆ B │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 1 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 2 ┆ 2 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 3 │ + ╰───────┴───────╯ + + (Showing first 3 of 3 rows) + +In the above example, we query the DataFrame called `"my_special_df"` by simply referring to it in the SQL command. This produces a new DataFrame `sql_df` which can +natively integrate with the rest of your Daft query. + +Reading data from SQL +--------------------- + +.. WARNING:: + + This feature is a WIP and will be coming soon! We will support reading common datasources directly from SQL: + + .. code-block:: python + + daft.sql("SELECT * FROM read_parquet('s3://...')") + daft.sql("SELECT * FROM read_delta_lake('s3://...')") + + Today, a workaround for this is to construct your dataframe in Python first and use it from SQL instead: + + .. code-block:: python + + df = daft.read_parquet("s3://...") + daft.sql("SELECT * FROM df") + + We appreciate your patience with us and hope to deliver this crucial feature soon! + +SQL Expressions +--------------- + +SQL has the concept of expressions as well. Here is an example of a simple addition expression, adding columns "a" and "b" in SQL to produce a new column C. + +We also present here the equivalent query for SQL and DataFrame. Notice how similar the concepts are! + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + df = daft.sql("SELECT A + B as C FROM df") + df.show() + + .. group-tab:: 🐍 Python + + .. code:: python + + expr = (daft.col("A") + daft.col("B")).alias("C") + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + df = df.select(expr) + df.show() + +.. 
code-block:: text + :caption: Output + + ╭───────╮ + │ C │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 4 │ + ├╌╌╌╌╌╌╌┤ + │ 6 │ + ╰───────╯ + + (Showing first 3 of 3 rows) + +In the above query, both the SQL version of the query and the DataFrame version of the query produce the same result. + +Under the hood, they run the same Expression ``col("A") + col("B")``! + +One really cool trick you can do is to use the :func:`daft.sql_expr` function as a helper to easily create Expressions. The following are equivalent: + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + sql_expr = daft.sql_expr("A + B as C") + print("SQL expression:", sql_expr) + + .. group-tab:: 🐍 Python + + .. code:: python + + py_expr = (daft.col("A") + daft.col("B")).alias("C") + print("Python expression:", py_expr) + + +.. code-block:: text + :caption: Output + + SQL expression: col(A) + col(B) as C + Python expression: col(A) + col(B) as C + +This means that you can pretty much use SQL anywhere you use Python expressions, making Daft extremely versatile at mixing workflows which leverage both SQL and Python. + +As an example, consider the filter query below and compare the two equivalent Python and SQL queries: + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + + # Daft automatically converts this string using `daft.sql_expr` + df = df.where("A < 2") + + df.show() + + .. group-tab:: 🐍 Python + + .. code:: python + + df = daft.from_pydict({"A": [1, 2, 3], "B": [1, 2, 3]}) + + # Using Daft's Python Expression API + df = df.where(df["A"] < 2) + + df.show() + +.. code-block:: text + :caption: Output + + ╭───────┬───────╮ + │ A ┆ B │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 1 │ + ╰───────┴───────╯ + + (Showing first 1 of 1 rows) + +Pretty sweet! Of course, this support for running Expressions on your columns extends well beyond arithmetic as we'll see in the next section on SQL Functions. + +SQL Functions +------------- + +SQL also has access to all of Daft's powerful :class:`daft.Expression` functionality through SQL functions. + +However, unlike the Python Expression API which encourages method-chaining (e.g. ``col("a").url.download().image.decode()``), in SQL you have to do function nesting instead (e.g. ``image_decode(url_download(a))``). + +.. NOTE:: + + A full catalog of the available SQL Functions in Daft is available in the :doc:`../api_docs/sql`. + + Note that it closely mirrors the Python API, with some function naming differences vs the available Python methods. + We also have some aliased functions for ANSI SQL-compliance or familiarity to users coming from other common SQL dialects such as PostgreSQL and SparkSQL to easily find their functionality. + +Here is an example of an equivalent function call in SQL vs Python: + +.. tabs:: + + .. group-tab:: ⚙️ SQL + + .. code:: python + + df = daft.from_pydict({"urls": [ + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + ]}) + df = daft.sql("SELECT image_decode(url_download(urls)) FROM df") + df.show() + + .. group-tab:: 🐍 Python + + ..
code:: python + + df = daft.from_pydict({"urls": [ + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + "https://user-images.githubusercontent.com/17691182/190476440-28f29e87-8e3b-41c4-9c28-e112e595f558.png", + ]}) + df = df.select(daft.col("urls").url.download().image.decode()) + df.show() + +.. code-block:: text + :caption: Output + + ╭──────────────╮ + │ urls │ + │ --- │ + │ Image[MIXED] │ + ╞══════════════╡ + │ │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ │ + ╰──────────────╯ + + (Showing first 3 of 3 rows) diff --git a/docs/source/user_guide/daft_in_depth/udf.rst b/docs/source/user_guide/udf.rst similarity index 100% rename from docs/source/user_guide/daft_in_depth/udf.rst rename to docs/source/user_guide/udf.rst diff --git a/requirements-dev.txt b/requirements-dev.txt index a67574df90..9c7809ac80 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -88,3 +88,4 @@ sphinx-book-theme==1.1.0; python_version >= "3.9" sphinx-reredirects>=0.1.1 sphinx-copybutton>=0.5.2 sphinx-autosummary-accessors==2023.4.0; python_version >= "3.9" +sphinx-tabs==3.4.5 diff --git a/src/arrow2/src/array/list/mod.rs b/src/arrow2/src/array/list/mod.rs index 3948e12002..6d0735ca04 100644 --- a/src/arrow2/src/array/list/mod.rs +++ b/src/arrow2/src/array/list/mod.rs @@ -209,12 +209,18 @@ impl ListArray { if O::IS_LARGE { match data_type.to_logical_type() { DataType::LargeList(child) => Ok(child.as_ref()), - _ => Err(Error::oos("ListArray expects DataType::LargeList")), + got => { + let msg = format!("ListArray expects DataType::LargeList, but got {got:?}"); + Err(Error::oos(msg)) + }, } } else { match data_type.to_logical_type() { DataType::List(child) => Ok(child.as_ref()), - _ => Err(Error::oos("ListArray expects DataType::List")), + got => { + let msg = format!("ListArray expects DataType::List, but got {got:?}"); + Err(Error::oos(msg)) + }, } } } diff --git a/src/arrow2/src/array/map/mod.rs b/src/arrow2/src/array/map/mod.rs index d0dcb46efb..d49a463a2f 100644 --- a/src/arrow2/src/array/map/mod.rs +++ b/src/arrow2/src/array/map/mod.rs @@ -1,3 +1,4 @@ +use super::{new_empty_array, specification::try_check_offsets_bounds, Array, ListArray}; use crate::{ bitmap::Bitmap, datatypes::{DataType, Field}, @@ -5,8 +6,6 @@ use crate::{ offset::OffsetsBuffer, }; -use super::{new_empty_array, specification::try_check_offsets_bounds, Array}; - mod ffi; pub(super) mod fmt; mod iterator; @@ -41,20 +40,27 @@ impl MapArray { try_check_offsets_bounds(&offsets, field.len())?; let inner_field = Self::try_get_field(&data_type)?; - if let DataType::Struct(inner) = inner_field.data_type() { - if inner.len() != 2 { - return Err(Error::InvalidArgumentError( - "MapArray's inner `Struct` must have 2 fields (keys and maps)".to_string(), - )); - } - } else { + + let inner_data_type = inner_field.data_type(); + let DataType::Struct(inner) = inner_data_type else { return Err(Error::InvalidArgumentError( - "MapArray expects `DataType::Struct` as its inner logical type".to_string(), + format!("MapArray expects `DataType::Struct` as its inner logical type, but found {inner_data_type:?}"), )); + }; + + let inner_len = inner.len(); + if inner_len != 2 { + let msg = format!( + "MapArray's inner `Struct` must have 2 fields (keys and maps), but found {} fields", + inner_len + ); + return Err(Error::InvalidArgumentError(msg)); } - if field.data_type() != 
inner_field.data_type() { + + let field_data_type = field.data_type(); + if field_data_type != inner_field.data_type() { return Err(Error::InvalidArgumentError( - "MapArray expects `field.data_type` to match its inner DataType".to_string(), + format!("MapArray expects `field.data_type` to match its inner DataType, but found \n{field_data_type:?}\nvs\n\n\n{inner_field:?}"), )); } @@ -195,6 +201,57 @@ impl MapArray { impl Array for MapArray { impl_common_array!(); + fn convert_logical_type(&self, target_data_type: DataType) -> Box<dyn Array> { + let is_target_map = matches!(target_data_type, DataType::Map { .. }); + + let DataType::Map(current_field, _) = self.data_type() else { + unreachable!( + "Expected MapArray to have Map data type, but found {:?}", + self.data_type() + ); + }; + + if is_target_map { + // For Map-to-Map conversions we can clone (the top-level representation is + // unchanged; we are still a Map) and then change the subtype in place. + let mut converted_array = self.to_boxed(); + converted_array.change_type(target_data_type); + return converted_array; + } + + // Target type is a LargeList, so we need to convert this MapArray to a ListArray first + let DataType::LargeList(target_field) = &target_data_type else { + panic!("MapArray can only be converted to Map or LargeList, but target type is {target_data_type:?}"); + }; + + + let current_physical_type = current_field.data_type.to_physical_type(); + let target_physical_type = target_field.data_type.to_physical_type(); + + if current_physical_type != target_physical_type { + panic!( + "Inner physical types must be equal for conversion. Current: {:?}, Target: {:?}", + current_physical_type, target_physical_type + ); + } + + let mut converted_field = self.field.clone(); + converted_field.change_type(target_field.data_type.clone()); + + let original_offsets = self.offsets().clone(); + let converted_offsets = unsafe { original_offsets.map_unchecked(|offset| offset as i64) }; + + let converted_list = ListArray::new( + target_data_type, + converted_offsets, + converted_field, + self.validity.clone(), + ); + + Box::new(converted_list) + } + fn validity(&self) -> Option<&Bitmap> { self.validity.as_ref() } diff --git a/src/arrow2/src/array/mod.rs b/src/arrow2/src/array/mod.rs index f77cc5d60d..4216812ae4 100644 --- a/src/arrow2/src/array/mod.rs +++ b/src/arrow2/src/array/mod.rs @@ -17,13 +17,12 @@ //! //! Most arrays contain a [`MutableArray`] counterpart that is neither clonable nor sliceable, but //! can be operated in-place. -use std::any::Any; -use std::sync::Arc; +use std::{any::Any, sync::Arc}; -use crate::error::Result; use crate::{ bitmap::{Bitmap, MutableBitmap}, datatypes::DataType, + error::Result, }; mod physical_binary; @@ -55,6 +54,21 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// When the validity is [`None`], all slots are valid. fn validity(&self) -> Option<&Bitmap>; + /// Returns an iterator over the direct children of this Array. + /// + /// This method is useful for accessing child Arrays in composite types such as struct arrays. + /// By default, it returns an empty iterator, as most array types do not have child arrays. + /// + /// # Returns + /// A boxed iterator yielding mutable references to child Arrays. + /// + /// # Examples + /// For a StructArray, this would return an iterator over its field arrays. + /// For most other array types, this returns an empty iterator.
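+ /// + /// As an illustrative sketch (not part of this crate), a struct-like array holding its + /// children in a `values: Vec<Box<dyn Array>>` field could override it as: + /// ```ignore + /// fn direct_children<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a mut dyn Array> + 'a> { + ///     // Yield each child as `&mut dyn Array` so callers can mutate it in place + ///     Box::new(self.values.iter_mut().map(|child| child.as_mut())) + /// } + /// ```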
+    fn direct_children<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a mut dyn Array> + 'a> {
+        Box::new(core::iter::empty())
+    }
+
     /// The number of null slots on this [`Array`].
     /// # Implementation
     /// This is `O(1)` since the number of null elements is pre-computed.
@@ -144,46 +158,51 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static {
     /// Clone a `&dyn Array` to an owned `Box<dyn Array>`.
     fn to_boxed(&self) -> Box<dyn Array>;
 
-    /// Overwrites [`Array`]'s type with a different logical type.
+    /// Changes the logical type of this array in-place.
     ///
-    /// This function is useful to assign a different [`DataType`] to the array.
-    /// Used to change the arrays' logical type (see example). This updates the array
-    /// in place and does not clone the array.
-    /// # Example
-    /// ```rust,ignore
-    /// use arrow2::array::Int32Array;
-    /// use arrow2::datatypes::DataType;
+    /// This method modifies the array's `DataType` without changing its underlying data.
+    /// It's useful for reinterpreting the logical meaning of the data (e.g., from Int32 to Date32).
+    ///
+    /// # Arguments
+    /// * `data_type` - The new [`DataType`] to assign to this array.
     ///
-    /// let &mut array = Int32Array::from(&[Some(1), None, Some(2)])
-    /// array.to(DataType::Date32);
-    /// assert_eq!(
-    ///    format!("{:?}", array),
-    ///    "Date32[1970-01-02, None, 1970-01-03]"
-    /// );
-    /// ```
     /// # Panics
-    /// Panics iff the `data_type`'s [`PhysicalType`] is not equal to array's `PhysicalType`.
+    /// Panics if the new `data_type`'s [`PhysicalType`] is not equal to the array's current [`PhysicalType`].
+    ///
+    /// # Example
+    /// ```
+    /// # use arrow2::array::{Array, Int32Array};
+    /// # use arrow2::datatypes::DataType;
+    /// let mut array = Int32Array::from(&[Some(1), None, Some(2)]);
+    /// array.change_type(DataType::Date32);
+    /// assert_eq!(array.data_type(), &DataType::Date32);
+    /// ```
     fn change_type(&mut self, data_type: DataType);
 
-    /// Returns a new [`Array`] with a different logical type.
+    /// Creates a new [`Array`] with a different logical type.
     ///
-    /// This function is useful to assign a different [`DataType`] to the array.
-    /// Used to change the arrays' logical type (see example). Unlike, this clones the array
-    /// in order to return a new array.
-    /// # Example
-    /// ```rust,ignore
-    /// use arrow2::array::Int32Array;
-    /// use arrow2::datatypes::DataType;
+    /// This method returns a new array with the specified `DataType`, leaving the original array unchanged.
+    /// It's useful for creating a new view of the data with a different logical interpretation.
+    ///
+    /// # Arguments
+    /// * `data_type` - The [`DataType`] for the new array.
+    ///
+    /// # Returns
+    /// A new `Box<dyn Array>` with the specified `DataType`.
     ///
-    /// let array = Int32Array::from(&[Some(1), None, Some(2)]).to(DataType::Date32);
-    /// assert_eq!(
-    ///    format!("{:?}", array),
-    ///    "Date32[1970-01-02, None, 1970-01-03]"
-    /// );
-    /// ```
     /// # Panics
-    /// Panics iff the `data_type`'s [`PhysicalType`] is not equal to array's `PhysicalType`.
-    fn to_type(&self, data_type: DataType) -> Box<dyn Array> {
+    /// Panics if the new `data_type`'s [`PhysicalType`] is not equal to the array's current [`PhysicalType`].
+    ///
+    /// # Example
+    /// ```
+    /// # use arrow2::array::Int32Array;
+    /// # use arrow2::datatypes::DataType;
+    /// let array = Int32Array::from(&[Some(1), None, Some(2)]);
+    /// let new_array = array.convert_logical_type(DataType::Date32);
+    /// assert_eq!(new_array.data_type(), &DataType::Date32);
+    /// assert_eq!(array.data_type(), &DataType::Int32); // Original array unchanged
+    /// ```
+    fn convert_logical_type(&self, data_type: DataType) -> Box<dyn Array> {
         let mut new = self.to_boxed();
         new.change_type(data_type);
         new
@@ -634,14 +653,21 @@ macro_rules! impl_common_array {
         fn change_type(&mut self, data_type: DataType) {
             if data_type.to_physical_type() != self.data_type().to_physical_type() {
                 panic!(
-                    "Converting array with logical type {:?} to logical type {:?} failed, physical types do not match: {:?} -> {:?}",
+                    "Cannot change array type from {:?} to {:?}",
                     self.data_type(),
-                    data_type,
-                    self.data_type().to_physical_type(),
-                    data_type.to_physical_type(),
+                    data_type
                 );
             }
-            self.data_type = data_type;
+
+            self.data_type = data_type.clone();
+            let mut children = self.direct_children();
+
+            // Push the new logical type down into the child arrays, pairing each child
+            // DataType with the corresponding child Array.
+            data_type.direct_children(|child| {
+                let Some(child_elem) = children.next() else {
+                    return;
+                };
+                child_elem.change_type(child.clone());
+            })
         }
     };
}
@@ -710,17 +736,15 @@ pub mod dyn_ord;
 pub mod growable;
 pub mod ord;
 
-pub(crate) use iterator::ArrayAccessor;
-pub use iterator::ArrayValuesIter;
-
-pub use equal::equal;
-pub use fmt::{get_display, get_value_display};
-
 pub use binary::{BinaryArray, BinaryValueIter, MutableBinaryArray, MutableBinaryValuesArray};
 pub use boolean::{BooleanArray, MutableBooleanArray};
 pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray};
+pub use equal::equal;
 pub use fixed_size_binary::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray};
 pub use fixed_size_list::{FixedSizeListArray, MutableFixedSizeListArray};
+pub use fmt::{get_display, get_value_display};
+pub(crate) use iterator::ArrayAccessor;
+pub use iterator::ArrayValuesIter;
 pub use list::{ListArray, ListValuesIter, MutableListArray};
 pub use map::MapArray;
 pub use null::{MutableNullArray, NullArray};
@@ -729,9 +753,7 @@ pub use struct_::{MutableStructArray, StructArray};
 pub use union::UnionArray;
 pub use utf8::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array, Utf8ValuesIter};
 
-pub(crate) use self::ffi::offset_buffers_children_dictionary;
-pub(crate) use self::ffi::FromFfi;
-pub(crate) use self::ffi::ToFfi;
+pub(crate) use self::ffi::{offset_buffers_children_dictionary, FromFfi, ToFfi};
 
 /// A trait describing the ability of a struct to create itself from an iterator.
 /// This is similar to [`Extend`], but allows the creation to fail.
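With `change_type` now recursing through `direct_children`, `convert_logical_type` on a nested
array updates child types as well, for array types that override `direct_children`. A minimal,
hypothetical sketch with a list array, assuming upstream arrow2's `ListArray` constructors
(not part of this patch):

.. code-block:: rust

    use arrow2::array::{Array, Int32Array, ListArray};
    use arrow2::datatypes::{DataType, Field};
    use arrow2::offset::OffsetsBuffer;

    fn main() {
        let values = Int32Array::from_slice([1, 2, 3]).boxed();
        let dtype = ListArray::<i32>::default_datatype(DataType::Int32);
        let offsets = OffsetsBuffer::try_from(vec![0i32, 2, 3]).unwrap();
        let list = ListArray::<i32>::new(dtype, offsets, values, None);

        // Same physical layout (i32 child), different logical meaning (Date32 child).
        let target = DataType::List(Box::new(Field::new("item", DataType::Date32, true)));
        let converted = list.convert_logical_type(target.clone());
        assert_eq!(converted.data_type(), &target);
    }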
@@ -774,3 +796,96 @@ pub unsafe trait GenericBinaryArray<O: Offset>: Array {
     /// The offsets of the array
     fn offsets(&self) -> &[O];
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        array::{BooleanArray, Int32Array, Int64Array, ListArray, StructArray},
+        datatypes::{DataType, Field, TimeUnit},
+    };
+
+    #[test]
+    fn test_int32_to_date32() {
+        let array = Int32Array::from_slice([1, 2, 3]);
+        let result = array.convert_logical_type(DataType::Date32);
+        assert_eq!(result.data_type(), &DataType::Date32);
+    }
+
+    #[test]
+    fn test_int64_to_timestamp() {
+        let array = Int64Array::from_slice([1000, 2000, 3000]);
+        let result = array.convert_logical_type(DataType::Timestamp(TimeUnit::Millisecond, None));
+        assert_eq!(
+            result.data_type(),
+            &DataType::Timestamp(TimeUnit::Millisecond, None)
+        );
+    }
+
+    #[test]
+    fn test_boolean_to_boolean() {
+        let array = BooleanArray::from_slice([true, false, true]);
+        let result = array.convert_logical_type(DataType::Boolean);
+        assert_eq!(result.data_type(), &DataType::Boolean);
+    }
+
+    #[test]
+    fn test_list_to_list() {
+        let values = Int32Array::from_slice([1, 2, 3, 4, 5]);
+        let offsets = vec![0, 2, 5];
+        let list_array = ListArray::try_new(
+            DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
+            offsets.try_into().unwrap(),
+            Box::new(values),
+            None,
+        )
+        .unwrap();
+        let result = list_array.convert_logical_type(DataType::List(Box::new(Field::new(
+            "item",
+            DataType::Int32,
+            true,
+        ))));
+        assert_eq!(
+            result.data_type(),
+            &DataType::List(Box::new(Field::new("item", DataType::Int32, true)))
+        );
+    }
+
+    #[test]
+    fn test_struct_to_struct() {
+        let boolean = BooleanArray::from_slice([true, false, true]);
+        let int = Int32Array::from_slice([1, 2, 3]);
+        let struct_array = StructArray::try_new(
+            DataType::Struct(vec![
+                Field::new("b", DataType::Boolean, true),
+                Field::new("i", DataType::Int32, true),
+            ]),
+            vec![
+                Box::new(boolean) as Box<dyn Array>,
+                Box::new(int) as Box<dyn Array>,
+            ],
+            None,
+        )
+        .unwrap();
+        let result = struct_array.convert_logical_type(DataType::Struct(vec![
+            Field::new("b", DataType::Boolean, true),
+            Field::new("i", DataType::Int32, true),
+        ]));
+        assert_eq!(
+            result.data_type(),
+            &DataType::Struct(vec![
+                Field::new("b", DataType::Boolean, true),
+                Field::new("i", DataType::Int32, true),
+            ])
+        );
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_invalid_conversion() {
+        let array = Int32Array::from_slice([1, 2, 3]);
+        array.convert_logical_type(DataType::Utf8);
+    }
+}
diff --git a/src/arrow2/src/array/struct_/mod.rs b/src/arrow2/src/array/struct_/mod.rs
index fb2812375c..f096e1aeb6 100644
--- a/src/arrow2/src/array/struct_/mod.rs
+++ b/src/arrow2/src/array/struct_/mod.rs
@@ -1,3 +1,4 @@
+use std::ops::DerefMut;
 use crate::{
     bitmap::Bitmap,
     datatypes::{DataType, Field},
@@ -246,6 +247,14 @@ impl StructArray {
 impl Array for StructArray {
     impl_common_array!();
 
+    fn direct_children<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a mut dyn Array> + 'a> {
+        let iter = self.values
+            .iter_mut()
+            .map(|x| x.deref_mut());
+
+        Box::new(iter)
+    }
+
     fn validity(&self) -> Option<&Bitmap> {
         self.validity.as_ref()
     }
diff --git a/src/arrow2/src/compute/cast/mod.rs b/src/arrow2/src/compute/cast/mod.rs
index b48949b215..6ad12f2cb4 100644
--- a/src/arrow2/src/compute/cast/mod.rs
+++ b/src/arrow2/src/compute/cast/mod.rs
@@ -506,16 +506,16 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Result<Box<dyn Array>> {
     match (from_type, to_type) {
         (Null, _) | (_, Null) => Ok(new_null_array(to_type.clone(), array.len())),
         (Extension(_, from_inner, _), Extension(_, to_inner, _)) => {
-            let new_arr = cast(array.to_type(*from_inner.clone()).as_ref(), to_inner, options)?;
-            Ok(new_arr.to_type(to_type.clone()))
+            let new_arr = cast(array.convert_logical_type(*from_inner.clone()).as_ref(), to_inner, options)?;
+            Ok(new_arr.convert_logical_type(to_type.clone()))
         }
         (Extension(_, from_inner, _), _) => {
-            let new_arr = cast(array.to_type(*from_inner.clone()).as_ref(), to_type, options)?;
+            let new_arr = cast(array.convert_logical_type(*from_inner.clone()).as_ref(), to_type, options)?;
             Ok(new_arr)
         }
         (_, Extension(_, to_inner, _)) => {
             let new_arr = cast(array, to_inner, options)?;
-            Ok(new_arr.to_type(to_type.clone()))
+            Ok(new_arr.convert_logical_type(to_type.clone()))
         }
         (Struct(from_fields), Struct(to_fields)) => match (from_fields.len(), to_fields.len()) {
             (from_len, to_len) if from_len == to_len => {
diff --git a/src/arrow2/src/datatypes/mod.rs b/src/arrow2/src/datatypes/mod.rs
index 2debc5a4f2..b5c5b1a8b5 100644
--- a/src/arrow2/src/datatypes/mod.rs
+++ b/src/arrow2/src/datatypes/mod.rs
@@ -5,13 +5,11 @@ mod field;
 mod physical_type;
 mod schema;
 
+use std::{collections::BTreeMap, sync::Arc};
+
 pub use field::Field;
 pub use physical_type::*;
 pub use schema::Schema;
-
-use std::collections::BTreeMap;
-use std::sync::Arc;
-
 use serde::{Deserialize, Serialize};
 
 /// typedef for [BTreeMap<String, String>] denoting [`Field`]'s and [`Schema`]'s metadata.
@@ -19,6 +17,12 @@ pub type Metadata = BTreeMap<String, String>;
 /// typedef for [Option<(String, Option<String>)>] descr
 pub(crate) type Extension = Option<(String, Option<String>)>;
 
+#[allow(unused_imports, reason = "used in documentation")]
+use crate::array::Array;
+
+pub type ArrowDataType = DataType;
+pub type ArrowField = Field;
+
 /// The set of supported logical types in this crate.
 ///
 /// Each variant uniquely identifies a logical type, which define specific semantics to the data
@@ -159,6 +163,55 @@ pub enum DataType {
     Extension(String, Box<DataType>, Option<String>),
 }
 
+impl DataType {
+    pub fn map(field: impl Into<Box<Field>>, keys_sorted: bool) -> Self {
+        Self::Map(field.into(), keys_sorted)
+    }
+
+    /// Processes the direct children data types of this DataType.
+    ///
+    /// This method is useful for traversing the structure of complex data types.
+    /// It calls the provided closure for each immediate child data type.
+    ///
+    /// This can be used in conjunction with the [`Array::direct_children`] method
+    /// to process both the data types and the corresponding array data.
+    ///
+    /// # Arguments
+    ///
+    /// * `processor` - A closure that takes a reference to a DataType as its argument.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use arrow2::datatypes::{DataType, Field};
+    ///
+    /// let struct_type = DataType::Struct(vec![
+    ///     Field::new("a", DataType::Int32, true),
+    ///     Field::new("b", DataType::Utf8, false),
+    /// ]);
+    ///
+    /// let mut child_types = Vec::new();
+    /// struct_type.direct_children(|child_type| {
+    ///     child_types.push(child_type);
+    /// });
+    ///
+    /// assert_eq!(child_types, vec![&DataType::Int32, &DataType::Utf8]);
+    /// ```
+    pub fn direct_children<'a>(&'a self, mut processor: impl FnMut(&'a DataType)) {
+        match self {
+            DataType::List(field)
+            | DataType::FixedSizeList(field, _)
+            | DataType::LargeList(field)
+            | DataType::Map(field, ..) => processor(&field.data_type),
+            DataType::Struct(fields) | DataType::Union(fields, _, _) => {
+                fields.iter().for_each(|field| processor(&field.data_type))
+            }
+            DataType::Dictionary(_, value_type, _) => processor(value_type),
+            _ => {} // Other types don't have child data types
+        }
+    }
+}
+
 /// Mode of [`DataType::Union`]
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub enum UnionMode {
diff --git a/src/arrow2/src/error.rs b/src/arrow2/src/error.rs
index 3b7eaadf3e..3df1e19381 100644
--- a/src/arrow2/src/error.rs
+++ b/src/arrow2/src/error.rs
@@ -55,6 +55,7 @@ impl Error {
         Self::OutOfSpec(msg.into())
     }
 
+    #[allow(unused)]
     pub(crate) fn nyi<A: Into<String>>(msg: A) -> Self {
         Self::NotYetImplemented(msg.into())
     }
diff --git a/src/arrow2/src/io/parquet/mod.rs b/src/arrow2/src/io/parquet/mod.rs
index 7fe33f8564..655df953e8 100644
--- a/src/arrow2/src/io/parquet/mod.rs
+++ b/src/arrow2/src/io/parquet/mod.rs
@@ -22,6 +22,9 @@ impl From<parquet2::error::Error> for Error {
             parquet2::error::Error::Transport(msg) => {
                 Error::Io(std::io::Error::new(std::io::ErrorKind::Other, msg))
             }
+            parquet2::error::Error::IoError(msg) => {
+                Error::Io(std::io::Error::new(std::io::ErrorKind::Other, msg))
+            }
             _ => Error::ExternalFormat(error.to_string()),
         }
     }
diff --git a/src/arrow2/src/offset.rs b/src/arrow2/src/offset.rs
index 80b45d6680..3d7a2aa869 100644
--- a/src/arrow2/src/offset.rs
+++ b/src/arrow2/src/offset.rs
@@ -1,9 +1,8 @@
 //! Contains the declaration of [`Offset`]
 use std::hint::unreachable_unchecked;
 
-use crate::buffer::Buffer;
-use crate::error::Error;
 pub use crate::types::Offset;
+use crate::{buffer::Buffer, error::Error};
 
 /// A wrapper type of [`Vec<O>`] representing the invariants of Arrow's offsets.
 /// It is guaranteed to (sound to assume that):
@@ -144,10 +143,9 @@ impl<O: Offset> Offsets<O> {
     /// Returns the last offset of this container.
     #[inline]
     pub fn last(&self) -> &O {
-        match self.0.last() {
-            Some(element) => element,
-            None => unsafe { unreachable_unchecked() },
-        }
+        self.0
+            .last()
+            .unwrap_or_else(|| unsafe { unreachable_unchecked() })
     }
 
     /// Returns a range (start, end) corresponding to the position `index`
@@ -338,7 +336,7 @@ fn try_check_offsets<O: Offset>(offsets: &[O]) -> Result<(), Error> {
 /// * Every element is `>= 0`
 /// * element at position `i` is >= than element at position `i-1`.
 #[derive(Clone, PartialEq, Debug)]
-pub struct OffsetsBuffer<O>(Buffer<O>);
+pub struct OffsetsBuffer<O: Offset>(Buffer<O>);
 
 impl<O: Offset> Default for OffsetsBuffer<O> {
     #[inline]
@@ -347,6 +345,39 @@ impl<O: Offset> Default for OffsetsBuffer<O> {
     }
 }
 
+impl<O: Offset> OffsetsBuffer<O> {
+    /// Maps each offset to a new value, creating a new [`Self`].
+    ///
+    /// # Safety
+    ///
+    /// This function is marked as `unsafe` because it does not check whether the resulting offsets
+    /// maintain the invariants required by [`OffsetsBuffer`]. The caller must ensure that:
+    ///
+    /// - The resulting offsets are monotonically increasing.
+    /// - The first offset is zero.
+    /// - All offsets are non-negative.
+    ///
+    /// Violating these invariants can lead to undefined behavior when using the resulting [`OffsetsBuffer`].
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use arrow2::offset::OffsetsBuffer;
+    /// # let offsets = unsafe { OffsetsBuffer::new_unchecked(vec![0, 2, 5, 7].into()) };
+    /// let doubled = unsafe { offsets.map_unchecked(|x| x * 2) };
+    /// assert_eq!(doubled.buffer().as_slice(), &[0, 4, 10, 14]);
+    /// ```
+    ///
+    /// Note that in this example, doubling the offsets maintains the required invariants,
+    /// but this may not be true for all transformations.
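In `MapArray::convert_logical_type` shown earlier, this helper's real job is widening `i32` map
offsets into the `i64` offsets a LargeList needs. A hypothetical sketch of that call, under the
reconstructed `map_unchecked<T: Offset>` signature below:

.. code-block:: rust

    use arrow2::offset::OffsetsBuffer;

    fn main() {
        let offsets: OffsetsBuffer<i32> = vec![0i32, 2, 5].try_into().unwrap();
        // Safety: widening i32 -> i64 preserves order, non-negativity, and the leading zero.
        let widened: OffsetsBuffer<i64> = unsafe { offsets.map_unchecked(|o| o as i64) };
        assert_eq!(widened.buffer().as_slice(), &[0i64, 2, 5]);
    }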
+    pub unsafe fn map_unchecked<T: Offset>(&self, f: impl Fn(O) -> T) -> OffsetsBuffer<T> {
+        let buffer = self.0.iter().copied().map(f).collect();
+
+        OffsetsBuffer(buffer)
+    }
+}
+
 impl<O: Offset> OffsetsBuffer<O> {
     /// # Safety
     /// This is safe iff the invariants of this struct are guaranteed in `offsets`.
@@ -401,22 +432,29 @@ impl<O: Offset> OffsetsBuffer<O> {
         *self.last() - *self.first()
     }
 
+    pub fn ranges(&self) -> impl Iterator<Item = core::ops::Range<O>> + '_ {
+        self.0.windows(2).map(|w| {
+            let from = w[0];
+            let to = w[1];
+            debug_assert!(from <= to, "offsets must be monotonically increasing");
+            from..to
+        })
+    }
+
     /// Returns the first offset.
     #[inline]
     pub fn first(&self) -> &O {
-        match self.0.first() {
-            Some(element) => element,
-            None => unsafe { unreachable_unchecked() },
-        }
+        self.0
+            .first()
+            .unwrap_or_else(|| unsafe { unreachable_unchecked() })
     }
 
     /// Returns the last offset.
     #[inline]
     pub fn last(&self) -> &O {
-        match self.0.last() {
-            Some(element) => element,
-            None => unsafe { unreachable_unchecked() },
-        }
+        self.0
+            .last()
+            .unwrap_or_else(|| unsafe { unreachable_unchecked() })
     }
 
     /// Returns a range (start, end) corresponding to the position `index`
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs
index 077d4b7e83..2202b20d39 100644
--- a/src/common/daft-config/src/lib.rs
+++ b/src/common/daft-config/src/lib.rs
@@ -14,6 +14,7 @@ pub struct DaftPlanningConfig {
 }
 
 impl DaftPlanningConfig {
+    #[must_use]
     pub fn from_env() -> Self {
         let mut cfg = Self::default();
 
@@ -84,6 +85,7 @@ impl Default for DaftExecutionConfig {
 }
 
 impl DaftExecutionConfig {
+    #[must_use]
     pub fn from_env() -> Self {
         let mut cfg = Self::default();
         let aqe_env_var_name = "DAFT_ENABLE_AQE";
diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs
index 44bb95c1b0..27663f4841 100644
--- a/src/common/daft-config/src/python.rs
+++ b/src/common/daft-config/src/python.rs
@@ -16,24 +16,34 @@ pub struct PyDaftPlanningConfig {
 #[pymethods]
 impl PyDaftPlanningConfig {
     #[new]
+    #[must_use]
     pub fn new() -> Self {
         Self::default()
     }
 
     #[staticmethod]
+    #[must_use]
     pub fn from_env() -> Self {
         Self {
             config: Arc::new(DaftPlanningConfig::from_env()),
         }
     }
 
-    fn with_config_values(&mut self, default_io_config: Option<PyIOConfig>) -> PyResult<Self> {
+    fn with_config_values(
+        &mut self,
+        default_io_config: Option<PyIOConfig>,
+        enable_actor_pool_projections: Option<bool>,
+    ) -> PyResult<Self> {
         let mut config = self.config.as_ref().clone();
 
         if let Some(default_io_config) = default_io_config {
             config.default_io_config = default_io_config.config;
         }
 
+        if let Some(enable_actor_pool_projections) = enable_actor_pool_projections {
+            config.enable_actor_pool_projections = enable_actor_pool_projections;
+        }
+
         Ok(Self {
             config: Arc::new(config),
         })
@@ -63,11 +73,13 @@ pub struct PyDaftExecutionConfig {
 #[pymethods]
 impl PyDaftExecutionConfig {
     #[new]
+    #[must_use]
     pub fn new() -> Self {
         Self::default()
     }
 
     #[staticmethod]
+    #[must_use]
     pub fn from_env() -> Self {
         Self {
             config: Arc::new(DaftExecutionConfig::from_env()),
diff --git a/src/common/display/src/ascii.rs b/src/common/display/src/ascii.rs
index 4851fac320..4365f1f25b 100644
--- a/src/common/display/src/ascii.rs
+++ b/src/common/display/src/ascii.rs
@@ -45,6 +45,8 @@ fn fmt_tree_gitstyle<'a, W: fmt::Write + 'a>(
     s: &'a mut W,
     level: crate::DisplayLevel,
 ) -> fmt::Result {
+    use terminal_size::{terminal_size, Width};
+
     // Print the current node.
     // e.g.
| | * // | | | @@ -52,7 +54,6 @@ fn fmt_tree_gitstyle<'a, W: fmt::Write + 'a>( let desc = node.display_as(level); let lines = desc.lines(); - use terminal_size::{terminal_size, Width}; let size = terminal_size(); let term_width = if let Some((Width(w), _)) = size { w as usize diff --git a/src/common/display/src/mermaid.rs b/src/common/display/src/mermaid.rs index 41b452b528..fd64d8663c 100644 --- a/src/common/display/src/mermaid.rs +++ b/src/common/display/src/mermaid.rs @@ -102,7 +102,7 @@ where if display.is_empty() { return Err(fmt::Error); } - writeln!(self.output, r#"{}["{}"]"#, id, display)?; + writeln!(self.output, r#"{id}["{display}"]"#)?; self.nodes.insert(node.id(), id); Ok(()) @@ -146,21 +146,18 @@ where } pub fn fmt(&mut self, node: &dyn TreeDisplay) -> fmt::Result { - match &self.subgraph_options { - Some(SubgraphOptions { name, subgraph_id }) => { - writeln!(self.output, r#"subgraph {subgraph_id}["{name}"]"#)?; - self.fmt_node(node)?; - writeln!(self.output, "end")?; - } - None => { - if self.bottom_up { - writeln!(self.output, "flowchart BT")?; - } else { - writeln!(self.output, "flowchart TD")?; - } - - self.fmt_node(node)?; + if let Some(SubgraphOptions { name, subgraph_id }) = &self.subgraph_options { + writeln!(self.output, r#"subgraph {subgraph_id}["{name}"]"#)?; + self.fmt_node(node)?; + writeln!(self.output, "end")?; + } else { + if self.bottom_up { + writeln!(self.output, "flowchart BT")?; + } else { + writeln!(self.output, "flowchart TD")?; } + + self.fmt_node(node)?; } Ok(()) } diff --git a/src/common/display/src/table_display.rs b/src/common/display/src/table_display.rs index 8f0ba51d1a..c3a6efc30b 100644 --- a/src/common/display/src/table_display.rs +++ b/src/common/display/src/table_display.rs @@ -57,6 +57,9 @@ pub fn make_comfy_table>( num_rows: Option, max_col_width: Option, ) -> comfy_table::Table { + const DOTS: &str = "…"; + const TOTAL_ROWS: usize = 10; + let mut table = comfy_table::Table::new(); let default_width_if_no_tty = 120usize; @@ -74,22 +77,17 @@ pub fn make_comfy_table>( let expected_col_width = 18usize; - let max_cols = (((terminal_width + expected_col_width - 1) / expected_col_width) - 1).max(1); - const DOTS: &str = "…"; + let max_cols = (terminal_width.div_ceil(expected_col_width) - 1).max(1); let num_columns = fields.len(); - let head_cols; - let tail_cols; - let total_cols; - if num_columns > max_cols { - head_cols = (max_cols + 1) / 2; - tail_cols = max_cols / 2; - total_cols = head_cols + tail_cols + 1; + let (head_cols, tail_cols, total_cols) = if num_columns > max_cols { + let head_cols = (max_cols + 1) / 2; + let tail_cols = max_cols / 2; + (head_cols, tail_cols, head_cols + tail_cols + 1) } else { - head_cols = num_columns; - tail_cols = 0; - total_cols = head_cols; - } + (num_columns, 0, num_columns) + }; + let mut header = fields .iter() .take(head_cols) @@ -98,12 +96,8 @@ pub fn make_comfy_table>( if tail_cols > 0 { let unseen_cols = num_columns - (head_cols + tail_cols); header.push( - create_table_cell(&format!( - "{DOTS}\n\n({unseen_cols} hidden)", - DOTS = DOTS, - unseen_cols = unseen_cols - )) - .set_alignment(comfy_table::CellAlignment::Center), + create_table_cell(&format!("{DOTS}\n\n({unseen_cols} hidden)")) + .set_alignment(comfy_table::CellAlignment::Center), ); header.extend( fields @@ -118,17 +112,11 @@ pub fn make_comfy_table>( { table.set_header(header); let len = num_rows.expect("if columns are set, so should `num_rows`"); - const TOTAL_ROWS: usize = 10; - let head_rows; - let tail_rows; - - if len > TOTAL_ROWS { - 
head_rows = TOTAL_ROWS / 2;
-            tail_rows = TOTAL_ROWS / 2;
+        let (head_rows, tail_rows) = if len > TOTAL_ROWS {
+            (TOTAL_ROWS / 2, TOTAL_ROWS / 2)
         } else {
-            head_rows = len;
-            tail_rows = 0;
-        }
+            (len, 0)
+        };
 
         for i in 0..head_rows {
             let all_cols = columns
diff --git a/src/common/display/src/tree.rs b/src/common/display/src/tree.rs
index b2e5e0c0e4..4380323983 100644
--- a/src/common/display/src/tree.rs
+++ b/src/common/display/src/tree.rs
@@ -17,7 +17,10 @@ pub trait TreeDisplay {
     fn id(&self) -> String {
         let mut s = String::new();
         s.push_str(&self.get_name());
-        s.push_str(&format!("{:p}", self as *const Self as *const ()));
+        s.push_str(&format!(
+            "{:p}",
+            std::ptr::from_ref::<Self>(self).cast::<()>()
+        ));
         s
     }
diff --git a/src/common/display/src/utils.rs b/src/common/display/src/utils.rs
index 082ec0a883..cf86588c69 100644
--- a/src/common/display/src/utils.rs
+++ b/src/common/display/src/utils.rs
@@ -1,9 +1,11 @@
+#[must_use]
 pub fn bytes_to_human_readable(byte_count: usize) -> String {
+    const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"];
+
     if byte_count == 0 {
         return "0 B".to_string();
     }
 
-    const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"];
     let base = byte_count.ilog2() / 10; // log2(1024) = 10
     let index = std::cmp::min(base, (UNITS.len() - 1) as u32);
diff --git a/src/common/error/Cargo.toml b/src/common/error/Cargo.toml
index b64ef5c901..29e9a1123e 100644
--- a/src/common/error/Cargo.toml
+++ b/src/common/error/Cargo.toml
@@ -1,10 +1,13 @@
 [dependencies]
-arrow2 = {workspace = true}
+arrow2 = {workspace = true, features = ["io_parquet"]}
 pyo3 = {workspace = true, optional = true}
 regex = {workspace = true}
 serde_json = {workspace = true}
 thiserror = {workspace = true}
 
+[dev-dependencies]
+parquet2 = {workspace = true}
+
 [features]
 python = ["dep:pyo3"]
diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs
index 0513d3e112..31cb71b7ba 100644
--- a/src/common/error/src/error.rs
+++ b/src/common/error/src/error.rs
@@ -14,7 +14,7 @@ pub enum DaftError {
     #[error("DaftError::ComputeError {0}")]
     ComputeError(String),
     #[error("DaftError::ArrowError {0}")]
-    ArrowError(#[from] arrow2::error::Error),
+    ArrowError(arrow2::error::Error),
     #[error("DaftError::ValueError {0}")]
     ValueError(String),
     #[cfg(feature = "python")]
@@ -47,3 +47,51 @@ pub enum DaftError {
     #[error("DaftError::RegexError {0}")]
     RegexError(#[from] regex::Error),
 }
+
+impl From<arrow2::error::Error> for DaftError {
+    fn from(error: arrow2::error::Error) -> Self {
+        match error {
+            arrow2::error::Error::Io(_) => Self::ByteStreamError(error.into()),
+            _ => Self::ArrowError(error),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::io::ErrorKind;
+
+    use super::*;
+
+    #[test]
+    fn test_arrow_io_error_conversion() {
+        // Ensure that arrow2 IO errors get converted into transient Byte Stream errors.
+        let error_message = "IO error occurred";
+        let arrow_io_error =
+            arrow2::error::Error::Io(std::io::Error::new(ErrorKind::Other, error_message));
+        let daft_error: DaftError = arrow_io_error.into();
+        match daft_error {
+            DaftError::ByteStreamError(e) => {
+                assert_eq!(e.to_string(), format!("Io error: {error_message}"));
+            }
+            _ => panic!("Expected ByteStreamError"),
+        }
+    }
+
+    #[test]
+    fn test_parquet_io_error_conversion() {
+        // Ensure that parquet2 IO errors get converted into transient Byte Stream errors.
+        let error_message = "IO error occurred";
+        let parquet_io_error =
+            parquet2::error::Error::IoError(std::io::Error::new(ErrorKind::Other, error_message));
+        let arrow_error: arrow2::error::Error = parquet_io_error.into();
+        let daft_error: DaftError = arrow_error.into();
+        match daft_error {
+            DaftError::ByteStreamError(e) => {
+                assert_eq!(e.to_string(), format!("Io error: {error_message}"));
+            }
+            _ => panic!("Expected ByteStreamError"),
+        }
+    }
+}
diff --git a/src/common/file-formats/src/file_format.rs b/src/common/file-formats/src/file_format.rs
index c5e553aceb..15a7684813 100644
--- a/src/common/file-formats/src/file_format.rs
+++ b/src/common/file-formats/src/file_format.rs
@@ -39,7 +39,7 @@ impl FromStr for FileFormat {
     type Err = DaftError;
 
     fn from_str(file_format: &str) -> DaftResult<Self> {
-        use FileFormat::*;
+        use FileFormat::{Csv, Database, Json, Parquet};
 
         if file_format.trim().eq_ignore_ascii_case("parquet") {
             Ok(Parquet)
@@ -51,8 +51,7 @@ impl FromStr for FileFormat {
             Ok(Database)
         } else {
             Err(DaftError::TypeError(format!(
-                "FileFormat {} not supported!",
-                file_format
+                "FileFormat {file_format} not supported!"
             )))
         }
     }
diff --git a/src/common/file-formats/src/file_format_config.rs b/src/common/file-formats/src/file_format_config.rs
index fe659bc444..5d166ddbeb 100644
--- a/src/common/file-formats/src/file_format_config.rs
+++ b/src/common/file-formats/src/file_format_config.rs
@@ -25,24 +25,25 @@ pub enum FileFormatConfig {
 }
 
 impl FileFormatConfig {
+    #[must_use]
     pub fn file_format(&self) -> FileFormat {
         self.into()
     }
 
+    #[must_use]
     pub fn var_name(&self) -> &'static str {
-        use FileFormatConfig::*;
-
         match self {
-            Parquet(_) => "Parquet",
-            Csv(_) => "Csv",
-            Json(_) => "Json",
+            Self::Parquet(_) => "Parquet",
+            Self::Csv(_) => "Csv",
+            Self::Json(_) => "Json",
             #[cfg(feature = "python")]
-            Database(_) => "Database",
+            Self::Database(_) => "Database",
             #[cfg(feature = "python")]
-            PythonFunction => "PythonFunction",
+            Self::PythonFunction => "PythonFunction",
         }
     }
 
+    #[must_use]
     pub fn multiline_display(&self) -> Vec<String> {
         match self {
             Self::Parquet(source) => source.multiline_display(),
@@ -76,6 +77,7 @@ pub struct ParquetSourceConfig {
 }
 
 impl ParquetSourceConfig {
+    #[must_use]
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
         res.push(format!(
@@ -101,7 +103,7 @@ impl ParquetSourceConfig {
                 rg.as_ref()
                     .map(|rg| {
                         rg.iter()
-                            .map(|i| i.to_string())
+                            .map(std::string::ToString::to_string)
                             .collect::<Vec<String>>()
                             .join(",")
                     })
@@ -115,6 +117,17 @@ impl ParquetSourceConfig {
     }
 }
 
+impl Default for ParquetSourceConfig {
+    fn default() -> Self {
+        Self {
+            coerce_int96_timestamp_unit: TimeUnit::Nanoseconds,
+            field_id_mapping: None,
+            row_groups: None,
+            chunk_size: None,
+        }
+    }
+}
+
 #[cfg(feature = "python")]
 #[pymethods]
 impl ParquetSourceConfig {
@@ -128,13 +141,10 @@ impl ParquetSourceConfig {
     ) -> Self {
         Self {
             coerce_int96_timestamp_unit: coerce_int96_timestamp_unit
-                .unwrap_or(TimeUnit::Nanoseconds.into())
+                .unwrap_or_else(|| TimeUnit::Nanoseconds.into())
                 .into(),
-            field_id_mapping: field_id_mapping.map(|map| {
-                Arc::new(BTreeMap::from_iter(
-                    map.into_iter().map(|(k, v)| (k, v.field)),
-                ))
-            }),
+            field_id_mapping: field_id_mapping
+                .map(|map| Arc::new(map.into_iter().map(|(k, v)| (k, v.field)).collect())),
             row_groups,
             chunk_size,
         }
@@ -164,31 +174,32 @@ pub struct CsvSourceConfig {
 }
 
 impl CsvSourceConfig {
+    #[must_use]
     pub fn multiline_display(&self) -> Vec<String> {
         let mut res = vec![];
         if let 
Some(delimiter) = self.delimiter { - res.push(format!("Delimiter = {}", delimiter)); + res.push(format!("Delimiter = {delimiter}")); } res.push(format!("Has headers = {}", self.has_headers)); res.push(format!("Double quote = {}", self.double_quote)); if let Some(quote) = self.quote { - res.push(format!("Quote = {}", quote)); + res.push(format!("Quote = {quote}")); } if let Some(escape_char) = self.escape_char { - res.push(format!("Escape char = {}", escape_char)); + res.push(format!("Escape char = {escape_char}")); } if let Some(comment) = self.comment { - res.push(format!("Comment = {}", comment)); + res.push(format!("Comment = {comment}")); } res.push(format!( "Allow_variable_columns = {}", self.allow_variable_columns )); if let Some(buffer_size) = self.buffer_size { - res.push(format!("Buffer size = {}", buffer_size)); + res.push(format!("Buffer size = {buffer_size}")); } if let Some(chunk_size) = self.chunk_size { - res.push(format!("Chunk size = {}", chunk_size)); + res.push(format!("Chunk size = {chunk_size}")); } res } @@ -243,6 +254,7 @@ pub struct JsonSourceConfig { } impl JsonSourceConfig { + #[must_use] pub fn new_internal(buffer_size: Option, chunk_size: Option) -> Self { Self { buffer_size, @@ -250,13 +262,14 @@ impl JsonSourceConfig { } } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(buffer_size) = self.buffer_size { - res.push(format!("Buffer size = {}", buffer_size)); + res.push(format!("Buffer size = {buffer_size}")); } if let Some(chunk_size) = self.chunk_size { - res.push(format!("Chunk size = {}", chunk_size)); + res.push(format!("Chunk size = {chunk_size}")); } res } @@ -323,10 +336,12 @@ impl Hash for DatabaseSourceConfig { #[cfg(feature = "python")] impl DatabaseSourceConfig { + #[must_use] pub fn new_internal(sql: String, conn: PyObject) -> Self { Self { sql, conn } } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("SQL = \"{}\"", self.sql)); diff --git a/src/common/io-config/src/azure.rs b/src/common/io-config/src/azure.rs index 1aac69e17a..ba264bdca6 100644 --- a/src/common/io-config/src/azure.rs +++ b/src/common/io-config/src/azure.rs @@ -38,28 +38,29 @@ impl Default for AzureConfig { } impl AzureConfig { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(storage_account) = &self.storage_account { - res.push(format!("Storage account = {}", storage_account)); + res.push(format!("Storage account = {storage_account}")); } if let Some(access_key) = &self.access_key { - res.push(format!("Access key = {}", access_key)); + res.push(format!("Access key = {access_key}")); } if let Some(sas_token) = &self.sas_token { - res.push(format!("Shared Access Signature = {}", sas_token)); + res.push(format!("Shared Access Signature = {sas_token}")); } if let Some(bearer_token) = &self.bearer_token { - res.push(format!("Bearer Token = {}", bearer_token)); + res.push(format!("Bearer Token = {bearer_token}")); } if let Some(tenant_id) = &self.tenant_id { - res.push(format!("Tenant ID = {}", tenant_id)); + res.push(format!("Tenant ID = {tenant_id}")); } if let Some(client_id) = &self.client_id { - res.push(format!("Client ID = {}", client_id)); + res.push(format!("Client ID = {client_id}")); } if let Some(client_secret) = &self.client_secret { - res.push(format!("Client Secret = {}", client_secret)); + res.push(format!("Client Secret = {client_secret}")); } res.push(format!( "Use Fabric Endpoint = {}", @@ -67,7 +68,7 @@ impl AzureConfig { )); 
res.push(format!("Anonymous = {}", self.anonymous)); if let Some(endpoint_url) = &self.endpoint_url { - res.push(format!("Endpoint URL = {}", endpoint_url)); + res.push(format!("Endpoint URL = {endpoint_url}")); } res.push(format!("Use SSL = {}", self.use_ssl)); res diff --git a/src/common/io-config/src/config.rs b/src/common/io-config/src/config.rs index 7d9ce2230e..a0af9c3caa 100644 --- a/src/common/io-config/src/config.rs +++ b/src/common/io-config/src/config.rs @@ -12,6 +12,7 @@ pub struct IOConfig { } impl IOConfig { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( diff --git a/src/common/io-config/src/gcs.rs b/src/common/io-config/src/gcs.rs index cdbf57671d..cd5e8628a3 100644 --- a/src/common/io-config/src/gcs.rs +++ b/src/common/io-config/src/gcs.rs @@ -13,10 +13,11 @@ pub struct GCSConfig { } impl GCSConfig { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(project_id) = &self.project_id { - res.push(format!("Project ID = {}", project_id)); + res.push(format!("Project ID = {project_id}")); } res.push(format!("Anonymous = {}", self.anonymous)); res diff --git a/src/common/io-config/src/http.rs b/src/common/io-config/src/http.rs index 6241de3028..554de2cec9 100644 --- a/src/common/io-config/src/http.rs +++ b/src/common/io-config/src/http.rs @@ -22,17 +22,18 @@ impl Default for HTTPConfig { impl HTTPConfig { pub fn new>(bearer_token: Option) -> Self { Self { - bearer_token: bearer_token.map(|t| t.into()), + bearer_token: bearer_token.map(std::convert::Into::into), ..Default::default() } } } impl HTTPConfig { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut v = vec![format!("user_agent = {}", self.user_agent)]; if let Some(bearer_token) = &self.bearer_token { - v.push(format!("bearer_token = {}", bearer_token)); + v.push(format!("bearer_token = {bearer_token}")); } v @@ -52,8 +53,7 @@ impl Display for HTTPConfig { write!( f, " - bearer_token: {}", - bearer_token + bearer_token: {bearer_token}" ) } else { Ok(()) diff --git a/src/common/io-config/src/lib.rs b/src/common/io-config/src/lib.rs index 46e4de278d..ae620a112d 100644 --- a/src/common/io-config/src/lib.rs +++ b/src/common/io-config/src/lib.rs @@ -27,6 +27,7 @@ pub use crate::{ pub struct ObfuscatedString(Secret); impl ObfuscatedString { + #[must_use] pub fn as_string(&self) -> &String { self.0.expose_secret() } @@ -42,7 +43,7 @@ impl Eq for ObfuscatedString {} impl Hash for ObfuscatedString { fn hash(&self, state: &mut H) { - self.0.expose_secret().hash(state) + self.0.expose_secret().hash(state); } } diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 6ae67a2443..2632ebecdd 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -72,6 +72,7 @@ pub struct S3Credentials { } /// Create configurations to be used when accessing Azure Blob Storage. +/// /// To authenticate with Microsoft Entra ID, `tenant_id`, `client_id`, and `client_secret` must be provided. /// If no credentials are provided, Daft will attempt to fetch credentials from the environment. /// @@ -98,6 +99,7 @@ pub struct AzureConfig { } /// Create configurations to be used when accessing Google Cloud Storage. +/// /// Credentials may be provided directly with the `credentials` parameter, or set with the `GOOGLE_APPLICATION_CREDENTIALS_JSON` or `GOOGLE_APPLICATION_CREDENTIALS` environment variables. 
/// /// Args: @@ -148,6 +150,7 @@ pub struct HTTPConfig { #[pymethods] impl IOConfig { #[new] + #[must_use] pub fn new( s3: Option, azure: Option, @@ -164,6 +167,7 @@ impl IOConfig { } } + #[must_use] pub fn replace( &self, s3: Option, @@ -173,14 +177,18 @@ impl IOConfig { ) -> Self { Self { config: config::IOConfig { - s3: s3.map(|s3| s3.config).unwrap_or(self.config.s3.clone()), + s3: s3 + .map(|s3| s3.config) + .unwrap_or_else(|| self.config.s3.clone()), azure: azure .map(|azure| azure.config) - .unwrap_or(self.config.azure.clone()), - gcs: gcs.map(|gcs| gcs.config).unwrap_or(self.config.gcs.clone()), + .unwrap_or_else(|| self.config.azure.clone()), + gcs: gcs + .map(|gcs| gcs.config) + .unwrap_or_else(|| self.config.gcs.clone()), http: http .map(|http| http.config) - .unwrap_or(self.config.http.clone()), + .unwrap_or_else(|| self.config.http.clone()), }, } } @@ -279,8 +287,10 @@ impl S3Config { region_name: region_name.or(def.region_name), endpoint_url: endpoint_url.or(def.endpoint_url), key_id: key_id.or(def.key_id), - session_token: session_token.map(|v| v.into()).or(def.session_token), - access_key: access_key.map(|v| v.into()).or(def.access_key), + session_token: session_token + .map(std::convert::Into::into) + .or(def.session_token), + access_key: access_key.map(std::convert::Into::into).or(def.access_key), credentials_provider: credentials_provider .map(|p| { Ok::<_, PyErr>(Box::new(PyS3CredentialsProvider::new(p)?) @@ -339,10 +349,10 @@ impl S3Config { endpoint_url: endpoint_url.or_else(|| self.config.endpoint_url.clone()), key_id: key_id.or_else(|| self.config.key_id.clone()), session_token: session_token - .map(|v| v.into()) + .map(std::convert::Into::into) .or_else(|| self.config.session_token.clone()), access_key: access_key - .map(|v| v.into()) + .map(std::convert::Into::into) .or_else(|| self.config.access_key.clone()), credentials_provider: credentials_provider .map(|p| { @@ -416,7 +426,7 @@ impl S3Config { .config .session_token .as_ref() - .map(|v| v.as_string()) + .map(super::ObfuscatedString::as_string) .cloned()) } @@ -427,7 +437,7 @@ impl S3Config { .config .access_key .as_ref() - .map(|v| v.as_string()) + .map(super::ObfuscatedString::as_string) .cloned()) } @@ -671,7 +681,7 @@ impl S3CredentialsProvider for PyS3CredentialsProvider { } fn dyn_hash(&self, mut state: &mut dyn Hasher) { - self.hash(&mut state) + self.hash(&mut state); } } @@ -679,6 +689,7 @@ impl S3CredentialsProvider for PyS3CredentialsProvider { impl AzureConfig { #[allow(clippy::too_many_arguments)] #[new] + #[must_use] pub fn new( storage_account: Option, access_key: Option, @@ -696,12 +707,14 @@ impl AzureConfig { Self { config: crate::AzureConfig { storage_account: storage_account.or(def.storage_account), - access_key: access_key.map(|v| v.into()).or(def.access_key), + access_key: access_key.map(std::convert::Into::into).or(def.access_key), sas_token: sas_token.or(def.sas_token), bearer_token: bearer_token.or(def.bearer_token), tenant_id: tenant_id.or(def.tenant_id), client_id: client_id.or(def.client_id), - client_secret: client_secret.map(|v| v.into()).or(def.client_secret), + client_secret: client_secret + .map(std::convert::Into::into) + .or(def.client_secret), use_fabric_endpoint: use_fabric_endpoint.unwrap_or(def.use_fabric_endpoint), anonymous: anonymous.unwrap_or(def.anonymous), endpoint_url: endpoint_url.or(def.endpoint_url), @@ -711,6 +724,7 @@ impl AzureConfig { } #[allow(clippy::too_many_arguments)] + #[must_use] pub fn replace( &self, storage_account: Option, @@ -729,14 +743,14 @@ 
impl AzureConfig { config: crate::AzureConfig { storage_account: storage_account.or_else(|| self.config.storage_account.clone()), access_key: access_key - .map(|v| v.into()) + .map(std::convert::Into::into) .or_else(|| self.config.access_key.clone()), sas_token: sas_token.or_else(|| self.config.sas_token.clone()), bearer_token: bearer_token.or_else(|| self.config.bearer_token.clone()), tenant_id: tenant_id.or_else(|| self.config.tenant_id.clone()), client_id: client_id.or_else(|| self.config.client_id.clone()), client_secret: client_secret - .map(|v| v.into()) + .map(std::convert::Into::into) .or_else(|| self.config.client_secret.clone()), use_fabric_endpoint: use_fabric_endpoint.unwrap_or(self.config.use_fabric_endpoint), anonymous: anonymous.unwrap_or(self.config.anonymous), @@ -763,7 +777,7 @@ impl AzureConfig { .config .access_key .as_ref() - .map(|v| v.as_string()) + .map(super::ObfuscatedString::as_string) .cloned()) } @@ -795,7 +809,7 @@ impl AzureConfig { .config .client_secret .as_ref() - .map(|v| v.as_string()) + .map(super::ObfuscatedString::as_string) .cloned()) } @@ -828,6 +842,7 @@ impl AzureConfig { impl GCSConfig { #[allow(clippy::too_many_arguments)] #[new] + #[must_use] pub fn new( project_id: Option, credentials: Option, @@ -838,13 +853,16 @@ impl GCSConfig { Self { config: crate::GCSConfig { project_id: project_id.or(def.project_id), - credentials: credentials.map(|v| v.into()).or(def.credentials), + credentials: credentials + .map(std::convert::Into::into) + .or(def.credentials), token: token.or(def.token), anonymous: anonymous.unwrap_or(def.anonymous), }, } } + #[must_use] pub fn replace( &self, project_id: Option, @@ -856,7 +874,7 @@ impl GCSConfig { config: crate::GCSConfig { project_id: project_id.or_else(|| self.config.project_id.clone()), credentials: credentials - .map(|v| v.into()) + .map(std::convert::Into::into) .or_else(|| self.config.credentials.clone()), token: token.or_else(|| self.config.token.clone()), anonymous: anonymous.unwrap_or(self.config.anonymous), @@ -906,6 +924,7 @@ impl From for IOConfig { #[pymethods] impl HTTPConfig { #[new] + #[must_use] pub fn new(bearer_token: Option) -> Self { Self { config: crate::HTTPConfig::new(bearer_token), diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index cb02fad7fb..41db6c8b29 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -67,7 +67,7 @@ impl Eq for Box {} impl Hash for Box { fn hash(&self, state: &mut H) { - self.dyn_hash(state) + self.dyn_hash(state); } } @@ -83,28 +83,29 @@ impl ProvideCredentials for Box { } impl S3Config { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(region_name) = &self.region_name { - res.push(format!("Region name = {}", region_name)); + res.push(format!("Region name = {region_name}")); } if let Some(endpoint_url) = &self.endpoint_url { - res.push(format!("Endpoint URL = {}", endpoint_url)); + res.push(format!("Endpoint URL = {endpoint_url}")); } if let Some(key_id) = &self.key_id { - res.push(format!("Key ID = {}", key_id)); + res.push(format!("Key ID = {key_id}")); } if let Some(session_token) = &self.session_token { - res.push(format!("Session token = {}", session_token)); + res.push(format!("Session token = {session_token}")); } if let Some(access_key) = &self.access_key { - res.push(format!("Access key = {}", access_key)); + res.push(format!("Access key = {access_key}")); } if let Some(credentials_provider) = &self.credentials_provider { - res.push(format!("Credentials 
provider = {:?}", credentials_provider)); + res.push(format!("Credentials provider = {credentials_provider:?}")); } if let Some(buffer_time) = &self.buffer_time { - res.push(format!("Buffer time = {}", buffer_time)); + res.push(format!("Buffer time = {buffer_time}")); } res.push(format!( "Max connections = {}", @@ -118,7 +119,7 @@ impl S3Config { res.push(format!("Read timeout ms = {}", self.read_timeout_ms)); res.push(format!("Max retries = {}", self.num_tries)); if let Some(retry_mode) = &self.retry_mode { - res.push(format!("Retry mode = {}", retry_mode)); + res.push(format!("Retry mode = {retry_mode}")); } res.push(format!("Anonymous = {}", self.anonymous)); res.push(format!("Use SSL = {}", self.use_ssl)); @@ -130,7 +131,7 @@ impl S3Config { self.force_virtual_addressing )); if let Some(name) = &self.profile_name { - res.push(format!("Profile Name = {}", name)); + res.push(format!("Profile Name = {name}")); } res } @@ -214,13 +215,14 @@ impl Display for S3Config { } impl S3Credentials { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Key ID = {}", self.key_id)); res.push(format!("Access key = {}", self.access_key)); if let Some(session_token) = &self.session_token { - res.push(format!("Session token = {}", session_token)); + res.push(format!("Session token = {session_token}")); } if let Some(expiry) = &self.expiry { let expiry: DateTime = (*expiry).into(); diff --git a/src/common/py-serde/src/python.rs b/src/common/py-serde/src/python.rs index 63c20a4699..e634743f2d 100644 --- a/src/common/py-serde/src/python.rs +++ b/src/common/py-serde/src/python.rs @@ -7,8 +7,8 @@ use serde::{ ser::Error as SerError, Deserializer, Serializer, }; -#[cfg(feature = "python")] +#[cfg(feature = "python")] pub fn serialize_py_object(obj: &PyObject, s: S) -> Result where S: Serializer, @@ -23,10 +23,9 @@ where s.serialize_bytes(bytes.as_slice()) } #[cfg(feature = "python")] - struct PyObjectVisitor; -#[cfg(feature = "python")] +#[cfg(feature = "python")] impl<'de> Visitor<'de> for PyObjectVisitor { type Value = PyObject; diff --git a/src/common/resource-request/src/lib.rs b/src/common/resource-request/src/lib.rs index a422c91475..0b27d4a054 100644 --- a/src/common/resource-request/src/lib.rs +++ b/src/common/resource-request/src/lib.rs @@ -25,6 +25,7 @@ pub struct ResourceRequest { } impl ResourceRequest { + #[must_use] pub fn new_internal( num_cpus: Option, num_gpus: Option, @@ -37,10 +38,12 @@ impl ResourceRequest { } } + #[must_use] pub fn default_cpu() -> Self { Self::new_internal(Some(1.0), None, None) } + #[must_use] pub fn or_num_cpus(&self, num_cpus: Option) -> Self { Self { num_cpus: self.num_cpus.or(num_cpus), @@ -48,6 +51,7 @@ impl ResourceRequest { } } + #[must_use] pub fn or_num_gpus(&self, num_gpus: Option) -> Self { Self { num_gpus: self.num_gpus.or(num_gpus), @@ -55,6 +59,7 @@ impl ResourceRequest { } } + #[must_use] pub fn or_memory_bytes(&self, memory_bytes: Option) -> Self { Self { memory_bytes: self.memory_bytes.or(memory_bytes), @@ -62,20 +67,22 @@ impl ResourceRequest { } } + #[must_use] pub fn has_any(&self) -> bool { self.num_cpus.is_some() || self.num_gpus.is_some() || self.memory_bytes.is_some() } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut requests = vec![]; if let Some(num_cpus) = self.num_cpus { - requests.push(format!("num_cpus = {}", num_cpus)); + requests.push(format!("num_cpus = {num_cpus}")); } if let Some(num_gpus) = self.num_gpus { - requests.push(format!("num_gpus = {}", num_gpus)); + 
requests.push(format!("num_gpus = {num_gpus}")); } if let Some(memory_bytes) = self.memory_bytes { - requests.push(format!("memory_bytes = {}", memory_bytes)); + requests.push(format!("memory_bytes = {memory_bytes}")); } requests } @@ -85,6 +92,7 @@ impl ResourceRequest { /// /// Currently, this returns true unless one resource request has a non-zero CPU request and the other task has a /// non-zero GPU request. + #[must_use] pub fn is_pipeline_compatible_with(&self, other: &Self) -> bool { let self_num_cpus = self.num_cpus; let self_num_gpus = self.num_gpus; @@ -100,6 +108,7 @@ impl ResourceRequest { } } + #[must_use] pub fn max(&self, other: &Self) -> Self { let max_num_cpus = lift(float_max, self.num_cpus, other.num_cpus); let max_num_gpus = lift(float_max, self.num_gpus, other.num_gpus); @@ -112,9 +121,10 @@ impl ResourceRequest { ) -> Self { resource_requests .iter() - .fold(Default::default(), |acc, e| acc.max(e.as_ref())) + .fold(Self::default(), |acc, e| acc.max(e.as_ref())) } + #[must_use] pub fn multiply(&self, factor: f64) -> Self { Self::new_internal( self.num_cpus.map(|x| x * factor), @@ -148,7 +158,7 @@ impl Hash for ResourceRequest { fn hash(&self, state: &mut H) { self.num_cpus.map(FloatWrapper).hash(state); self.num_gpus.map(FloatWrapper).hash(state); - self.memory_bytes.hash(state) + self.memory_bytes.hash(state); } } @@ -174,12 +184,14 @@ fn float_max(left: f64, right: f64) -> f64 { #[pymethods] impl ResourceRequest { #[new] + #[must_use] pub fn new(num_cpus: Option, num_gpus: Option, memory_bytes: Option) -> Self { Self::new_internal(num_cpus, num_gpus, memory_bytes) } /// Take a field-wise max of the list of resource requests. #[staticmethod] + #[must_use] pub fn max_resources(resource_requests: Vec) -> Self { Self::max_all(&resource_requests.iter().collect::>()) } @@ -199,6 +211,7 @@ impl ResourceRequest { Ok(self.memory_bytes) } + #[must_use] pub fn with_num_cpus(&self, num_cpus: Option) -> Self { Self { num_cpus, @@ -206,6 +219,7 @@ impl ResourceRequest { } } + #[must_use] pub fn with_num_gpus(&self, num_gpus: Option) -> Self { Self { num_gpus, @@ -213,6 +227,7 @@ impl ResourceRequest { } } + #[must_use] pub fn with_memory_bytes(&self, memory_bytes: Option) -> Self { Self { memory_bytes, @@ -237,7 +252,7 @@ impl ResourceRequest { } fn __repr__(&self) -> PyResult { - Ok(format!("{:?}", self)) + Ok(format!("{self:?}")) } } impl_bincode_py_state_serialization!(ResourceRequest); diff --git a/src/common/system-info/src/lib.rs b/src/common/system-info/src/lib.rs index 3ef6ba180e..37cd75232f 100644 --- a/src/common/system-info/src/lib.rs +++ b/src/common/system-info/src/lib.rs @@ -23,14 +23,17 @@ impl Default for SystemInfo { #[pymethods] impl SystemInfo { #[new] + #[must_use] pub fn new() -> Self { - Default::default() + Self::default() } + #[must_use] pub fn cpu_count(&self) -> Option { self.info.physical_core_count().map(|x| x as u64) } + #[must_use] pub fn total_memory(&self) -> u64 { if let Some(cgroup) = self.info.cgroup_limits() { cgroup.total_memory diff --git a/src/common/tracing/src/lib.rs b/src/common/tracing/src/lib.rs index 699ad6816c..9db19c016c 100644 --- a/src/common/tracing/src/lib.rs +++ b/src/common/tracing/src/lib.rs @@ -12,36 +12,39 @@ lazy_static! 
{ pub fn init_tracing(enable_chrome_trace: bool) { use std::sync::atomic::Ordering; - if !TRACING_INIT.swap(true, Ordering::Relaxed) { - if enable_chrome_trace { - let mut mg = CHROME_GUARD_HANDLE.lock().unwrap(); - assert!( - mg.is_none(), - "Expected chrome flush guard to be None on init" - ); - let (chrome_layer, _guard) = ChromeLayerBuilder::new() - .trace_style(tracing_chrome::TraceStyle::Threaded) - .name_fn(Box::new(|event_or_span| { - match event_or_span { - tracing_chrome::EventOrSpan::Event(ev) => ev.metadata().name().into(), - tracing_chrome::EventOrSpan::Span(s) => { - // TODO: this is where we should extract out fields (such as node id to show the different pipelines) - s.name().into() - } - } - })) - .build(); - tracing::subscriber::set_global_default( - tracing_subscriber::registry().with(chrome_layer), - ) - .unwrap(); - *mg = Some(_guard); - } else { - // Do nothing for now - } - } else { - panic!("Cannot init tracing, already initialized!") + + assert!( + !TRACING_INIT.swap(true, Ordering::Relaxed), + "Cannot init tracing, already initialized!" + ); + + if !enable_chrome_trace { + return; // Do nothing for now } + + let mut mg = CHROME_GUARD_HANDLE.lock().unwrap(); + assert!( + mg.is_none(), + "Expected chrome flush guard to be None on init" + ); + + let (chrome_layer, guard) = ChromeLayerBuilder::new() + .trace_style(tracing_chrome::TraceStyle::Threaded) + .name_fn(Box::new(|event_or_span| { + match event_or_span { + tracing_chrome::EventOrSpan::Event(ev) => ev.metadata().name().into(), + tracing_chrome::EventOrSpan::Span(s) => { + // TODO: this is where we should extract out fields (such as node id to show the different pipelines) + s.name().into() + } + } + })) + .build(); + + tracing::subscriber::set_global_default(tracing_subscriber::registry().with(chrome_layer)) + .unwrap(); + + *mg = Some(guard); } pub fn refresh_chrome_trace() -> bool { diff --git a/src/common/treenode/src/lib.rs b/src/common/treenode/src/lib.rs index 2507de4986..f749040895 100644 --- a/src/common/treenode/src/lib.rs +++ b/src/common/treenode/src/lib.rs @@ -517,7 +517,7 @@ pub trait TreeNodeRewriter: Sized { } /// Controls how [`TreeNode`] recursions should proceed. -#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Clone, Copy, Eq)] pub enum TreeNodeRecursion { /// Continue recursion with the next node. Continue, @@ -585,7 +585,7 @@ impl TreeNodeRecursion { /// - [`TreeNode::transform_down`], /// - [`TreeNode::transform_up`], /// - [`TreeNode::transform_down_up`] -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Eq, Debug)] pub struct Transformed { pub data: T, pub transformed: bool, @@ -623,6 +623,7 @@ impl Transformed { } /// Returns self if self is transformed, otherwise returns other. + #[must_use] pub fn or(self, other: Self) -> Self { if self.transformed { self @@ -840,7 +841,9 @@ impl TransformedResult for Result> { } /// Helper trait for implementing [`TreeNode`] that have children stored as -/// `Arc`s. If some trait object, such as `dyn T`, implements this trait, +/// `Arc`s. +/// +/// If some trait object, such as `dyn T`, implements this trait, /// its related `Arc` will automatically implement [`TreeNode`]. pub trait DynTreeNode { /// Returns all children of the specified `TreeNode`. 
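The `#[must_use]` added to `Transformed::or` above matters because discarding the result silently
drops a rewrite. A minimal sketch of its semantics (hypothetical, assuming the crate builds as
`common_treenode` and keeps DataFusion's `yes`/`no` constructors):

.. code-block:: rust

    use common_treenode::Transformed;

    fn main() {
        let unchanged = Transformed::no("plan");
        let rewritten = Transformed::yes("optimized plan");

        // `or` returns `self` if it was transformed, otherwise falls back to `other`.
        let merged = unchanged.or(rewritten);
        assert!(merged.transformed);
        assert_eq!(merged.data, "optimized plan");
    }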
@@ -864,7 +867,9 @@ impl TreeNode for Arc { f: F, ) -> Result> { let children = self.arc_children(); - if !children.is_empty() { + if children.is_empty() { + Ok(Transformed::no(self)) + } else { let new_children = children.into_iter().map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` // along with the node containing transformed children. @@ -873,14 +878,14 @@ impl TreeNode for Arc { } else { Ok(Transformed::new(self, false, new_children.tnr)) } - } else { - Ok(Transformed::no(self)) } } } /// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for -/// trees that contain nodes with payloads. This approach ensures safe execution of algorithms +/// trees that contain nodes with payloads. +/// +/// This approach ensures safe execution of algorithms /// involving payloads, by enforcing rules for detaching and reattaching child nodes. pub trait ConcreteTreeNode: Sized { /// Provides read-only access to child nodes. @@ -906,13 +911,13 @@ impl TreeNode for T { f: F, ) -> Result> { let (new_self, children) = self.take_children(); - if !children.is_empty() { + if children.is_empty() { + Ok(Transformed::no(new_self)) + } else { let new_children = children.into_iter().map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` along with // the node containing transformed children. new_children.map_data(|new_children| new_self.with_new_children(new_children)) - } else { - Ok(Transformed::no(new_self)) } } } @@ -1013,7 +1018,7 @@ mod tests { "f_up(j)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1084,7 +1089,7 @@ mod tests { "f_up(j)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1118,7 +1123,7 @@ mod tests { "f_up(j)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1170,7 +1175,7 @@ mod tests { "f_up(j)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1225,7 +1230,7 @@ mod tests { "f_up(j)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1252,7 +1257,7 @@ mod tests { "f_down(a)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1286,7 +1291,7 @@ mod tests { fn f_down_stop_on_e_visits() -> Vec { vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1331,7 +1336,7 @@ mod tests { "f_up(a)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } @@ -1379,7 +1384,7 @@ mod tests { "f_up(e)", ] .into_iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect() } diff --git a/src/daft-compression/src/compression.rs b/src/daft-compression/src/compression.rs index 23ca21df47..7283bcc353 100644 --- a/src/daft-compression/src/compression.rs +++ b/src/daft-compression/src/compression.rs @@ -20,6 +20,7 @@ pub enum CompressionCodec { } impl CompressionCodec { + #[must_use] pub fn from_uri(uri: &str) -> Option { let url = Url::parse(uri); let path = match &url { @@ -32,8 +33,9 @@ impl CompressionCodec { .to_string(); Self::from_extension(extension.as_ref()) } + #[must_use] pub fn from_extension(extension: &str) -> Option { - use CompressionCodec::*; + use CompressionCodec::{Brotli, Bz, Deflate, Gzip, Lzma, Xz, Zlib, Zstd}; 
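For reference, the explicit-import cleanup here does not change behavior: `from_extension`
(whose `match` continues below) still keys off the file suffix, and `from_uri` strips the path
first. A hypothetical usage sketch, assuming the crate re-exports `CompressionCodec` at its root:

.. code-block:: rust

    use daft_compression::CompressionCodec;

    fn main() {
        // ".gz" maps to the Gzip decoder; unknown suffixes yield None.
        assert!(matches!(
            CompressionCodec::from_uri("s3://bucket/data.csv.gz"),
            Some(CompressionCodec::Gzip)
        ));
        assert!(CompressionCodec::from_uri("s3://bucket/data.csv").is_none());
    }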
match extension { "br" => Some(Brotli), "bz2" => Some(Bz), @@ -52,7 +54,7 @@ impl CompressionCodec { &self, reader: T, ) -> Pin> { - use CompressionCodec::*; + use CompressionCodec::{Brotli, Bz, Deflate, Gzip, Lzma, Xz, Zlib, Zstd}; match self { Brotli => Box::pin(BrotliDecoder::new(reader)), Bz => Box::pin(BzDecoder::new(reader)), diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index a8b5048b82..29aeb2179a 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -1,16 +1,19 @@ use std::sync::Arc; +use arrow2::offset::OffsetsBuffer; use common_error::{DaftError, DaftResult}; use crate::{ array::growable::{Growable, GrowableArray}, datatypes::{DaftArrayType, DataType, Field}, + prelude::ListArray, series::Series, }; #[derive(Clone, Debug)] pub struct FixedSizeListArray { pub field: Arc, + /// contains all the elements of the nested lists flattened into a single contiguous array. pub flat_child: Series, validity: Option, } @@ -37,16 +40,13 @@ impl FixedSizeListArray { "FixedSizeListArray::new received values with len {} but expected it to match len of validity {} * size: {}", flat_child.len(), validity.len(), - (validity.len() * size), + validity.len() * size, ) } - if child_dtype.as_ref() != flat_child.data_type() { - panic!( - "FixedSizeListArray::new expects the child series to have dtype {}, but received: {}", + assert!(!(child_dtype.as_ref() != flat_child.data_type()), "FixedSizeListArray::new expects the child series to have dtype {}, but received: {}", child_dtype, flat_child.data_type(), - ) - } + ); } _ => panic!( "FixedSizeListArray::new expected FixedSizeList datatype, but received field: {}", @@ -64,6 +64,13 @@ impl FixedSizeListArray { self.validity.as_ref() } + pub fn null_count(&self) -> usize { + match self.validity() { + None => 0, + Some(validity) => validity.unset_bits(), + } + } + pub fn concat(arrays: &[&Self]) -> DaftResult { if arrays.is_empty() { return Err(DaftError::ValueError( @@ -105,10 +112,12 @@ impl FixedSizeListArray { &self.field.name } + #[must_use] pub fn data_type(&self) -> &DataType { &self.field.dtype } + #[must_use] pub fn child_data_type(&self) -> &DataType { match &self.field.dtype { DataType::FixedSizeList(child, _) => child.as_ref(), @@ -116,6 +125,7 @@ impl FixedSizeListArray { } } + #[must_use] pub fn rename(&self, name: &str) -> Self { Self::new( Field::new(name, self.data_type().clone()), @@ -174,6 +184,27 @@ impl FixedSizeListArray { validity, )) } + + fn generate_offsets(&self) -> OffsetsBuffer { + let size = self.fixed_element_len(); + let len = self.len(); + + // Create new offsets + let offsets: Vec = (0..=len) + .map(|i| i64::try_from(i * size).unwrap()) + .collect(); + + OffsetsBuffer::try_from(offsets).expect("Failed to create OffsetsBuffer") + } + + pub fn to_list(&self) -> ListArray { + ListArray::new( + self.field.clone(), + self.flat_child.clone(), + self.generate_offsets(), + self.validity.clone(), + ) + } } impl<'a> IntoIterator for &'a FixedSizeListArray { diff --git a/src/daft-core/src/array/from.rs b/src/daft-core/src/array/from.rs index b48c16a4ba..4320a9ff8a 100644 --- a/src/daft-core/src/array/from.rs +++ b/src/daft-core/src/array/from.rs @@ -1,3 +1,8 @@ +#![expect( + clippy::fallible_impl_from, + reason = "TODO(andrewgazelka/others): This should really be changed in the future" +)] + use std::{borrow::Cow, sync::Arc}; use common_error::{DaftError, DaftResult}; @@ -106,7 +111,7 @@ impl From<(&str, 
&[Option])> for BooleanArray { fn from(item: (&str, &[Option])) -> Self { let (name, slice) = item; let arrow_array = Box::new(arrow2::array::BooleanArray::from_trusted_len_iter( - slice.iter().cloned(), + slice.iter().copied(), )); Self::new(Field::new(name, DataType::Boolean).into(), arrow_array).unwrap() } diff --git a/src/daft-core/src/array/growable/arrow_growable.rs b/src/daft-core/src/array/growable/arrow_growable.rs index 4cea3b2569..655c3a314b 100644 --- a/src/daft-core/src/array/growable/arrow_growable.rs +++ b/src/daft-core/src/array/growable/arrow_growable.rs @@ -32,7 +32,7 @@ where #[inline] fn add_nulls(&mut self, additional: usize) { - self.arrow2_growable.extend_validity(additional) + self.arrow2_growable.extend_validity(additional); } #[inline] @@ -195,11 +195,11 @@ impl<'a> ArrowExtensionGrowable<'a> { impl<'a> Growable for ArrowExtensionGrowable<'a> { #[inline] fn extend(&mut self, index: usize, start: usize, len: usize) { - self.child_growable.extend(index, start, len) + self.child_growable.extend(index, start, len); } #[inline] fn add_nulls(&mut self, additional: usize) { - self.child_growable.extend_validity(additional) + self.child_growable.extend_validity(additional); } #[inline] fn build(&mut self) -> DaftResult { diff --git a/src/daft-core/src/array/growable/bitmap_growable.rs b/src/daft-core/src/array/growable/bitmap_growable.rs index 9c08375656..33d90dec63 100644 --- a/src/daft-core/src/array/growable/bitmap_growable.rs +++ b/src/daft-core/src/array/growable/bitmap_growable.rs @@ -18,17 +18,17 @@ impl<'a> ArrowBitmapGrowable<'a> { Some(bm) => { let (bm_data, bm_start, _bm_len) = bm.as_slice(); self.mutable_bitmap - .extend_from_slice(bm_data, bm_start + start, len) + .extend_from_slice(bm_data, bm_start + start, len); } } } pub fn add_nulls(&mut self, additional: usize) { - self.mutable_bitmap.extend_constant(additional, false) + self.mutable_bitmap.extend_constant(additional, false); } pub fn build(self) -> arrow2::bitmap::Bitmap { - self.mutable_bitmap.clone().into() + self.mutable_bitmap.into() } } diff --git a/src/daft-core/src/array/growable/fixed_size_list_growable.rs b/src/daft-core/src/array/growable/fixed_size_list_growable.rs index fd10c8dd94..68f0650120 100644 --- a/src/daft-core/src/array/growable/fixed_size_list_growable.rs +++ b/src/daft-core/src/array/growable/fixed_size_list_growable.rs @@ -62,9 +62,8 @@ impl<'a> Growable for FixedSizeListGrowable<'a> { len * self.element_fixed_len, ); - match &mut self.growable_validity { - Some(growable_validity) => growable_validity.extend(index, start, len), - None => (), + if let Some(growable_validity) = &mut self.growable_validity { + growable_validity.extend(index, start, len); } } @@ -72,9 +71,8 @@ impl<'a> Growable for FixedSizeListGrowable<'a> { self.child_growable .add_nulls(additional * self.element_fixed_len); - match &mut self.growable_validity { - Some(growable_validity) => growable_validity.add_nulls(additional), - None => (), + if let Some(growable_validity) = &mut self.growable_validity { + growable_validity.add_nulls(additional); } } diff --git a/src/daft-core/src/array/growable/list_growable.rs b/src/daft-core/src/array/growable/list_growable.rs index 25f44761be..7b3e7805e2 100644 --- a/src/daft-core/src/array/growable/list_growable.rs +++ b/src/daft-core/src/array/growable/list_growable.rs @@ -71,9 +71,8 @@ impl<'a> Growable for ListGrowable<'a> { (end_offset - start_offset).to_usize(), ); - match &mut self.growable_validity { - Some(growable_validity) => growable_validity.extend(index, 
start, len), - None => (), + if let Some(growable_validity) = &mut self.growable_validity { + growable_validity.extend(index, start, len); } self.growable_offsets @@ -82,9 +81,8 @@ impl<'a> Growable for ListGrowable<'a> { } fn add_nulls(&mut self, additional: usize) { - match &mut self.growable_validity { - Some(growable_validity) => growable_validity.add_nulls(additional), - None => (), + if let Some(growable_validity) = &mut self.growable_validity { + growable_validity.add_nulls(additional); } self.growable_offsets.extend_constant(additional); } diff --git a/src/daft-core/src/array/growable/logical_growable.rs b/src/daft-core/src/array/growable/logical_growable.rs index be95087443..aaab91dca4 100644 --- a/src/daft-core/src/array/growable/logical_growable.rs +++ b/src/daft-core/src/array/growable/logical_growable.rs @@ -29,7 +29,7 @@ where } #[inline] fn add_nulls(&mut self, additional: usize) { - self.physical_growable.add_nulls(additional) + self.physical_growable.add_nulls(additional); } #[inline] fn build(&mut self) -> DaftResult { diff --git a/src/daft-core/src/array/growable/struct_growable.rs b/src/daft-core/src/array/growable/struct_growable.rs index fb266ebb88..f33d9c050b 100644 --- a/src/daft-core/src/array/growable/struct_growable.rs +++ b/src/daft-core/src/array/growable/struct_growable.rs @@ -64,12 +64,11 @@ impl<'a> StructGrowable<'a> { impl<'a> Growable for StructGrowable<'a> { fn extend(&mut self, index: usize, start: usize, len: usize) { for child_growable in &mut self.children_growables { - child_growable.extend(index, start, len) + child_growable.extend(index, start, len); } - match &mut self.growable_validity { - Some(growable_validity) => growable_validity.extend(index, start, len), - None => (), + if let Some(growable_validity) = &mut self.growable_validity { + growable_validity.extend(index, start, len); } } @@ -78,9 +77,8 @@ impl<'a> Growable for StructGrowable<'a> { child_growable.add_nulls(additional); } - match &mut self.growable_validity { - Some(growable_validity) => growable_validity.add_nulls(additional), - None => (), + if let Some(growable_validity) = &mut self.growable_validity { + growable_validity.add_nulls(additional); } } diff --git a/src/daft-core/src/array/image_array.rs b/src/daft-core/src/array/image_array.rs index 5daa11d42d..9574f3745c 100644 --- a/src/daft-core/src/array/image_array.rs +++ b/src/daft-core/src/array/image_array.rs @@ -138,9 +138,7 @@ impl ImageArray { let offsets = arrow2::offset::OffsetsBuffer::try_from(offsets)?; let arrow_dtype: arrow2::datatypes::DataType = T::PRIMITIVE.into(); if let DataType::Image(Some(mode)) = &data_type { - if mode.get_dtype().to_arrow()? != arrow_dtype { - panic!("Inner value dtype of provided dtype {data_type:?} is inconsistent with inferred value dtype {arrow_dtype:?}"); - } + assert!(!(mode.get_dtype().to_arrow()? != arrow_dtype), "Inner value dtype of provided dtype {data_type:?} is inconsistent with inferred value dtype {arrow_dtype:?}"); } let data_array = ListArray::new( Field::new("data", DataType::List(Box::new((&arrow_dtype).into()))), diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 538c24e716..d5aa207231 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -12,6 +12,8 @@ use crate::{ pub struct ListArray { pub field: Arc, pub flat_child: Series, + + /// Where each row starts and ends. Null rows usually have the same start/end index, but this is not guaranteed. 
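+    /// For example (illustrative): offsets `[0, 2, 2, 5]` describe three rows
+    /// drawn from `flat_child`: elements `0..2`, an empty (or null) row, and
+    /// elements `2..5`.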
offsets: arrow2::offset::OffsetsBuffer, validity: Option, } @@ -37,16 +39,13 @@ impl ListArray { { panic!("ListArray::new validity length does not match computed length from offsets") } - if child_dtype.as_ref() != flat_child.data_type() { - panic!( - "ListArray::new expects the child series to have field {}, but received: {}", - child_dtype, - flat_child.data_type(), - ) - } - if *offsets.last() > flat_child.len() as i64 { - panic!("ListArray::new received offsets with last value {}, but child series has length {}", offsets.last(), flat_child.len()) - } + assert!( + !(child_dtype.as_ref() != flat_child.data_type()), + "ListArray::new expects the child series to have field {}, but received: {}", + child_dtype, + flat_child.data_type(), + ); + assert!(*offsets.last() <= flat_child.len() as i64, "ListArray::new received offsets with last value {}, but child series has length {}", offsets.last(), flat_child.len()); } _ => panic!( "ListArray::new expected List datatype, but received field: {}", @@ -201,6 +200,15 @@ impl<'a> IntoIterator for &'a ListArray { } } +impl ListArray { + pub fn iter(&self) -> ListArrayIter<'_> { + ListArrayIter { + array: self, + idx: 0, + } + } +} + pub struct ListArrayIter<'a> { array: &'a ListArray, idx: usize, diff --git a/src/daft-core/src/array/mod.rs b/src/daft-core/src/array/mod.rs index 7c300c6a38..d75eabb535 100644 --- a/src/daft-core/src/array/mod.rs +++ b/src/daft-core/src/array/mod.rs @@ -18,11 +18,12 @@ pub mod prelude; use std::{marker::PhantomData, sync::Arc}; use common_error::{DaftError, DaftResult}; +use daft_schema::field::DaftField; use crate::datatypes::{DaftArrayType, DaftPhysicalType, DataType, Field}; #[derive(Debug)] -pub struct DataArray { +pub struct DataArray { pub field: Arc, pub data: Box, marker_: PhantomData, @@ -40,30 +41,43 @@ impl DaftArrayType for DataArray { } } -impl DataArray -where - T: DaftPhysicalType, -{ - pub fn new(field: Arc, data: Box) -> DaftResult { +impl DataArray { + pub fn new( + physical_field: Arc, + arrow_array: Box, + ) -> DaftResult { assert!( - field.dtype.is_physical(), + physical_field.dtype.is_physical(), "Can only construct DataArray for PhysicalTypes, got {}", - field.dtype + physical_field.dtype ); - if let Ok(arrow_dtype) = field.dtype.to_physical().to_arrow() { - if !arrow_dtype.eq(data.data_type()) { - panic!( - "expected {:?}, got {:?} when creating a new DataArray", - arrow_dtype, - data.data_type() - ) - } + if let Ok(expected_arrow_physical_type) = physical_field.dtype.to_arrow() { + let arrow_data_type = arrow_array.data_type(); + + assert!( + !(&expected_arrow_physical_type != arrow_data_type), + "Mismatch between expected and actual Arrow types for DataArray.\n\ + Field name: {}\n\ + Logical type: {}\n\ + Physical type: {}\n\ + Expected Arrow physical type: {:?}\n\ + Actual Arrow Logical type: {:?} + + This error typically occurs when there's a discrepancy between the Daft DataType \ + and the underlying Arrow representation. 
Please ensure that the physical type \ + of the Daft DataType matches the Arrow type of the provided data.", + physical_field.name, + physical_field.dtype, + physical_field.dtype.to_physical(), + expected_arrow_physical_type, + arrow_data_type + ); } Ok(Self { - field, - data, + field: physical_field, + data: arrow_array, marker_: PhantomData, }) } diff --git a/src/daft-core/src/array/ops/arithmetic.rs b/src/daft-core/src/array/ops/arithmetic.rs index 21e23657c6..365c178a28 100644 --- a/src/daft-core/src/array/ops/arithmetic.rs +++ b/src/daft-core/src/array/ops/arithmetic.rs @@ -10,9 +10,6 @@ use crate::{ kernels::utf8::add_utf8_arrays, series::Series, }; -/// Helper function to perform arithmetic operations on a DataArray -/// Takes both Kernel (array x array operation) and operation (scalar x scalar) functions -/// The Kernel is used for when both arrays are non-unit length and the operation is used when broadcasting // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights @@ -31,6 +28,9 @@ use crate::{ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +/// Helper function to perform arithmetic operations on a DataArray +/// Takes both Kernel (array x array operation) and operation (scalar x scalar) functions +/// The Kernel is used for when both arrays are non-unit length and the operation is used when broadcasting fn arithmetic_helper( lhs: &DataArray, rhs: &DataArray, @@ -131,9 +131,7 @@ where T: arrow2::types::NativeType, F: Fn(T, T) -> T, { - if lhs.len() != rhs.len() { - panic!("expected same length") - } + assert!(lhs.len() == rhs.len(), "expected same length"); let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) { (None, _) => None, (_, None) => None, diff --git a/src/daft-core/src/array/ops/arrow2/comparison.rs b/src/daft-core/src/array/ops/arrow2/comparison.rs index 37f7b2a37b..a9c37c50fb 100644 --- a/src/daft-core/src/array/ops/arrow2/comparison.rs +++ b/src/daft-core/src/array/ops/arrow2/comparison.rs @@ -49,7 +49,7 @@ fn build_is_equal_with_nan( } } -fn build_is_equal( +pub fn build_is_equal( left: &dyn Array, right: &dyn Array, nulls_equal: bool, @@ -95,7 +95,7 @@ pub fn build_multi_array_is_equal( } let combined_fn = Box::new(move |a_idx: usize, b_idx: usize| -> bool { - for f in fn_list.iter() { + for f in &fn_list { if !f(a_idx, b_idx) { return false; } diff --git a/src/daft-core/src/array/ops/arrow2/sort/primitive/sort.rs b/src/daft-core/src/array/ops/arrow2/sort/primitive/sort.rs index 8535b8ca7c..62af2988b2 100644 --- a/src/daft-core/src/array/ops/arrow2/sort/primitive/sort.rs +++ b/src/daft-core/src/array/ops/arrow2/sort/primitive/sort.rs @@ -89,7 +89,7 @@ where // extend buffer with constants followed by non-null values buffer.resize(validity.unset_bits(), T::default()); for (start, len) in slices { - buffer.extend_from_slice(&values[start..start + len]) + buffer.extend_from_slice(&values[start..start + len]); } // sort values @@ -105,7 +105,7 @@ where // extend buffer with non-null values for (start, len) in slices { - buffer.extend_from_slice(&values[start..start + len]) + buffer.extend_from_slice(&values[start..start + len]); } // sort all non-null values @@ -200,7 +200,7 @@ mod tests { .unwrap() .clone(); let output = sort_by(&input, ord::total_cmp, &options, Some(3)); - assert_eq!(expected, output) + assert_eq!(expected, 
output);
     }

     #[test]
diff --git a/src/daft-core/src/array/ops/as_arrow.rs b/src/daft-core/src/array/ops/as_arrow.rs
index 51ba52dd2c..c2315d39cd 100644
--- a/src/daft-core/src/array/ops/as_arrow.rs
+++ b/src/daft-core/src/array/ops/as_arrow.rs
@@ -15,7 +15,9 @@ use crate::{
 pub trait AsArrow {
     type Output;

-    // Retrieve the underlying concrete Arrow2 array.
+    /// This does not correct for logical types and will just yield the physical type of the array.
+    /// For example, a TimestampArray will yield an arrow Int64Array rather than an arrow TimestampArray.
+    /// To get a corrected arrow type, see `.to_arrow()`.
     fn as_arrow(&self) -> &Self::Output;
 }

diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs
index c3dbe0c209..d495990d45 100644
--- a/src/daft-core/src/array/ops/cast.rs
+++ b/src/daft-core/src/array/ops/cast.rs
@@ -206,7 +206,7 @@ impl DateArray {
 /// Formats a naive timestamp to a string in the format "%Y-%m-%d %H:%M:%S%.f".
 /// Example: 2021-01-01 00:00:00
 /// See https://docs.rs/chrono/latest/chrono/format/strftime/index.html for format string options.
-pub(crate) fn timestamp_to_str_naive(val: i64, unit: &TimeUnit) -> String {
+pub fn timestamp_to_str_naive(val: i64, unit: &TimeUnit) -> String {
     let chrono_ts =
         arrow2::temporal_conversions::timestamp_to_naive_datetime(val, unit.to_arrow());
     let format_str = "%Y-%m-%d %H:%M:%S%.f";
     chrono_ts.format(format_str).to_string()
@@ -215,11 +215,7 @@ pub(crate) fn timestamp_to_str_naive(val: i64, unit: &TimeUnit) -> String {
 /// Formats a timestamp with an offset to a string in the format "%Y-%m-%d %H:%M:%S%.f %:z".
 /// Example: 2021-01-01 00:00:00 -07:00
 /// See https://docs.rs/chrono/latest/chrono/format/strftime/index.html for format string options.
-pub(crate) fn timestamp_to_str_offset(
-    val: i64,
-    unit: &TimeUnit,
-    offset: &chrono::FixedOffset,
-) -> String {
+pub fn timestamp_to_str_offset(val: i64, unit: &TimeUnit, offset: &chrono::FixedOffset) -> String {
     let chrono_ts =
         arrow2::temporal_conversions::timestamp_to_datetime(val, unit.to_arrow(), offset);
     let format_str = "%Y-%m-%d %H:%M:%S%.f %:z";
@@ -229,7 +225,7 @@ pub(crate) fn timestamp_to_str_offset(
 /// Formats a timestamp with a timezone to a string in the format "%Y-%m-%d %H:%M:%S%.f %Z".
 /// Example: 2021-01-01 00:00:00 PST
 /// See https://docs.rs/chrono/latest/chrono/format/strftime/index.html for format string options.
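 /// For instance (a hedged illustration): `timestamp_to_str_tz(0, &TimeUnit::Seconds, &chrono_tz::UTC)`
 /// should render "1970-01-01 00:00:00 UTC", since `%.f` contributes nothing when the fraction is zero.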
-pub(crate) fn timestamp_to_str_tz(val: i64, unit: &TimeUnit, tz: &chrono_tz::Tz) -> String { +pub fn timestamp_to_str_tz(val: i64, unit: &TimeUnit, tz: &chrono_tz::Tz) -> String { let chrono_ts = arrow2::temporal_conversions::timestamp_to_datetime(val, unit.to_arrow(), tz); let format_str = "%Y-%m-%d %H:%M:%S%.f %Z"; chrono_ts.format(format_str).to_string() @@ -647,9 +643,9 @@ fn extract_python_to_vec< if let Some(list_size) = list_size { if num_values != list_size { return Err(DaftError::ValueError(format!( - "Expected Array-like Object to have {list_size} elements but got {} at index {}", - num_values, i - ))); + "Expected Array-like Object to have {list_size} elements but got {} at index {}", + num_values, i + ))); } } else { offsets_vec.push(offsets_vec.last().unwrap() + num_values as i64); @@ -700,7 +696,7 @@ fn extract_python_to_vec< }; if collected.is_err() { - log::warn!("Could not convert python object to list at index: {i} for input series: {}", python_objects.name()) + log::warn!("Could not convert python object to list at index: {i} for input series: {}", python_objects.name()); } let collected: Vec = collected?; if let Some(list_size) = list_size { @@ -1351,7 +1347,7 @@ impl TensorArray { .call_method1(pyo3::intern!(py, "reshape"), (shape,))?; ndarrays.push(py_array.unbind()); } else { - ndarrays.push(py.None()) + ndarrays.push(py.None()); } } let values_array = @@ -1400,12 +1396,10 @@ impl TensorArray { let zero_series = Int64Array::from(("item", [0].as_slice())).into_series(); let mut non_zero_values = Vec::new(); let mut non_zero_indices = Vec::new(); - let mut offsets = Vec::::new(); for (i, (shape_series, data_series)) in shape_and_data_iter.enumerate() { let is_valid = validity.map_or(true, |v| v.get_bit(i)); if !is_valid { // Handle invalid row by populating dummy data. - offsets.push(1); non_zero_values.push(Series::empty("dummy", inner_dtype.as_ref())); non_zero_indices.push(Series::empty("dummy", &DataType::UInt64)); continue; @@ -1422,7 +1416,6 @@ impl TensorArray { let indices = UInt64Array::arange("item", 0, data_series.len() as i64, 1)? .into_series() .filter(&non_zero_mask)?; - offsets.push(data.len()); non_zero_values.push(data); non_zero_indices.push(indices); } @@ -1635,7 +1628,7 @@ impl SparseTensorArray { }) }) .collect(); - let offsets: Offsets = Offsets::try_from_iter(sizes_vec.iter().cloned())?; + let offsets: Offsets = Offsets::try_from_iter(sizes_vec.iter().copied())?; let n_values = sizes_vec.iter().sum::(); let validity = non_zero_indices_array.validity(); let item = cast_sparse_to_dense_for_inner_dtype( @@ -1698,6 +1691,44 @@ impl SparseTensorArray { ); Ok(sparse_tensor_array.into_series()) } + #[cfg(feature = "python")] + DataType::Python => Python::with_gil(|py| { + let mut pydicts: Vec> = Vec::with_capacity(self.len()); + let sa = self.shape_array(); + let va = self.values_array(); + let ia = self.indices_array(); + let pyarrow = py.import_bound(pyo3::intern!(py, "pyarrow"))?; + for ((shape_array, values_array), indices_array) in + sa.into_iter().zip(va.into_iter()).zip(ia.into_iter()) + { + if let (Some(shape_array), Some(values_array), Some(indices_array)) = + (shape_array, values_array, indices_array) + { + let shape_array = shape_array.u64().unwrap().as_arrow(); + let shape = shape_array.values().to_vec(); + let py_values_array = + ffi::to_py_array(py, values_array.to_arrow(), &pyarrow)? + .call_method1(pyo3::intern!(py, "to_numpy"), (false,))?; + let py_indices_array = + ffi::to_py_array(py, indices_array.to_arrow(), &pyarrow)? 
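+                        // Hedged note: as with the values array above, the positional
+                        // `false` passed to pyarrow's `to_numpy` is `zero_copy_only`,
+                        // permitting a copying conversion for arrays (e.g. with nulls)
+                        // that cannot be exposed as a zero-copy view.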
+ .call_method1(pyo3::intern!(py, "to_numpy"), (false,))?; + let pydict = pyo3::types::PyDict::new_bound(py); + pydict.set_item("values", py_values_array)?; + pydict.set_item("indices", py_indices_array)?; + pydict.set_item("shape", shape)?; + pydicts.push(pydict.unbind().into()); + } else { + pydicts.push(py.None()); + } + } + let py_objects_array = + PseudoArrowArray::new(pydicts.into(), self.physical.validity().cloned()); + Ok(PythonArray::new( + Field::new(self.name(), dtype.clone()).into(), + py_objects_array.to_boxed(), + )? + .into_series()) + }), _ => self.physical.cast(dtype), } } @@ -1793,6 +1824,13 @@ impl FixedShapeSparseTensorArray { FixedShapeTensorArray::new(Field::new(self.name(), dtype.clone()), physical); Ok(fixed_shape_tensor_array.into_series()) } + #[cfg(feature = "python")] + (DataType::Python, DataType::FixedShapeSparseTensor(inner_dtype, _)) => { + let sparse_tensor_series = + self.cast(&DataType::SparseTensor(inner_dtype.clone()))?; + let sparse_pytensor_series = sparse_tensor_series.cast(&DataType::Python)?; + Ok(sparse_pytensor_series) + } (_, _) => self.physical.cast(dtype), } } @@ -1882,24 +1920,24 @@ impl FixedShapeTensorArray { let zero_series = Int64Array::from(("item", [0].as_slice())).into_series(); let mut non_zero_values = Vec::new(); let mut non_zero_indices = Vec::new(); - let mut offsets = Vec::::new(); for (i, data_series) in physical_arr.into_iter().enumerate() { let is_valid = validity.map_or(true, |v| v.get_bit(i)); if !is_valid { // Handle invalid row by populating dummy data. - offsets.push(1); non_zero_values.push(Series::empty("dummy", inner_dtype.as_ref())); non_zero_indices.push(Series::empty("dummy", &DataType::UInt64)); continue; } let data_series = data_series.unwrap(); - assert!(data_series.len() == tensor_shape.iter().product::() as usize); + assert_eq!( + data_series.len(), + tensor_shape.iter().product::() as usize + ); let non_zero_mask = data_series.not_equal(&zero_series)?; let data = data_series.filter(&non_zero_mask)?; let indices = UInt64Array::arange("item", 0, data_series.len() as i64, 1)? .into_series() .filter(&non_zero_mask)?; - offsets.push(data.len()); non_zero_values.push(data); non_zero_indices.push(indices); } @@ -2057,7 +2095,7 @@ impl ListArray { } Ok(FixedSizeListArray::new( Field::new(self.name(), dtype.clone()), - casted_child.clone(), + casted_child, None, ) .into_series()) @@ -2091,7 +2129,7 @@ impl ListArray { } } } - DataType::Map(..) => Ok(MapArray::new( + DataType::Map { .. } => Ok(MapArray::new( Field::new(self.name(), dtype.clone()), self.clone(), ) @@ -2198,7 +2236,10 @@ where { Python::with_gil(|py| { let arrow_dtype = array.data_type().to_arrow()?; - let arrow_array = array.as_arrow().to_type(arrow_dtype).with_validity(None); + let arrow_array = array + .as_arrow() + .convert_logical_type(arrow_dtype) + .with_validity(None); let pyarrow = py.import_bound(pyo3::intern!(py, "pyarrow"))?; let py_array: Vec = ffi::to_py_array(py, arrow_array.to_boxed(), &pyarrow)? .call_method0(pyo3::intern!(py, "to_pylist"))? 
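Since the three timestamp formatters above are now `pub` rather than `pub(crate)`, downstream crates can render raw epoch values directly. A minimal sketch, assuming the helpers remain at this module path and that `TimeUnit` is importable from `daft_core::datatypes` (both assumptions):

```rust
use daft_core::array::ops::cast::timestamp_to_str_naive; // assumed public path
use daft_core::datatypes::TimeUnit; // assumed re-export location

fn main() {
    // 1_609_459_200 seconds after the Unix epoch is 2021-01-01T00:00:00Z,
    // matching the doc example; `%.f` adds nothing for a zero fraction.
    let rendered = timestamp_to_str_naive(1_609_459_200, &TimeUnit::Seconds);
    assert_eq!(rendered, "2021-01-01 00:00:00");
}
```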
diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index aee84893de..8c941f8b2a 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -766,7 +766,7 @@ impl DaftLogical for BooleanArray { Bitmap::new_zeroed(self.len()), validity.cloned(), ); - return Ok(Self::from((self.name(), arrow_array))); + Ok(Self::from((self.name(), arrow_array))) } } @@ -780,9 +780,9 @@ impl DaftLogical for BooleanArray { validity.cloned(), ); return Ok(Self::from((self.name(), arrow_array))); - } else { - Ok(self.clone()) } + + Ok(self.clone()) } fn xor(&self, rhs: bool) -> Self::Output { diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index d3681ea3a5..c222f6190e 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -214,7 +214,7 @@ impl DaftConcatAggable for DataArray { #[cfg(test)] mod test { - use std::iter::repeat; + use std::{iter, iter::repeat}; use common_error::DaftResult; @@ -231,7 +231,9 @@ mod test { Field::new("foo", DataType::List(Box::new(DataType::Int64))), Int64Array::from(( "item", - Box::new(arrow2::array::Int64Array::from_iter([].iter())), + Box::new(arrow2::array::Int64Array::from_iter(iter::empty::< + &Option, + >())), )) .into_series(), arrow2::offset::OffsetsBuffer::::try_from(vec![0, 0, 0, 0])?, diff --git a/src/daft-core/src/array/ops/from_arrow.rs b/src/daft-core/src/array/ops/from_arrow.rs index 1739b524a9..7fd252ab02 100644 --- a/src/daft-core/src/array/ops/from_arrow.rs +++ b/src/daft-core/src/array/ops/from_arrow.rs @@ -21,7 +21,7 @@ where impl FromArrow for DataArray { fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { - Self::try_from((field.clone(), arrow_arr)) + Self::try_from((field, arrow_arr)) } } @@ -30,19 +30,16 @@ where ::ArrayType: FromArrow, { fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { - let data_array_field = Arc::new(Field::new(field.name.clone(), field.dtype.to_physical())); - let physical_arrow_arr = match field.dtype { - // TODO: Consolidate Map to use the same .to_type conversion as other logical types - // Currently, .to_type does not work for Map in Arrow2 because it requires physical types to be equivalent, - // but the physical type of MapArray in Arrow2 is a MapArray, not a ListArray - DataType::Map(..) 
=> arrow_arr, - _ => arrow_arr.to_type(data_array_field.dtype.to_arrow()?), - }; + let target_convert = field.to_physical(); + let target_convert_arrow = target_convert.dtype.to_arrow()?; + + let physical_arrow_array = arrow_arr.convert_logical_type(target_convert_arrow); + let physical = ::ArrayType::from_arrow( - data_array_field, - physical_arrow_arr, + Arc::new(target_convert), + physical_arrow_array, )?; - Ok(Self::new(field.clone(), physical)) + Ok(Self::new(field, physical)) } } @@ -69,8 +66,14 @@ impl FromArrow for FixedSizeListArray { } impl FromArrow for ListArray { - fn from_arrow(field: FieldRef, arrow_arr: Box) -> DaftResult { - match (&field.dtype, arrow_arr.data_type()) { + fn from_arrow( + target_field: FieldRef, + arrow_arr: Box, + ) -> DaftResult { + let target_dtype = &target_field.dtype; + let arrow_dtype = arrow_arr.data_type(); + + let result = match (target_dtype, arrow_dtype) { ( DataType::List(daft_child_dtype), arrow2::datatypes::DataType::List(arrow_child_field), @@ -79,47 +82,40 @@ impl FromArrow for ListArray { DataType::List(daft_child_dtype), arrow2::datatypes::DataType::LargeList(arrow_child_field), ) => { - let arrow_arr = arrow_arr.to_type(arrow2::datatypes::DataType::LargeList( - arrow_child_field.clone(), - )); + // unifying lists + let arrow_arr = arrow_arr.convert_logical_type( + arrow2::datatypes::DataType::LargeList(arrow_child_field.clone()), + ); + let arrow_arr = arrow_arr .as_any() - .downcast_ref::>() + .downcast_ref::>() // list array with i64 offsets .unwrap(); + let arrow_child_array = arrow_arr.values(); let child_series = Series::from_arrow( Arc::new(Field::new("list", daft_child_dtype.as_ref().clone())), arrow_child_array.clone(), )?; Ok(Self::new( - field.clone(), + target_field.clone(), child_series, arrow_arr.offsets().clone(), arrow_arr.validity().cloned(), )) } - (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map(..)) => { - let map_arr = arrow_arr - .as_any() - .downcast_ref::() - .unwrap(); - let arrow_child_array = map_arr.field(); - let child_series = Series::from_arrow( - Arc::new(Field::new("map", daft_child_dtype.as_ref().clone())), - arrow_child_array.clone(), - )?; - Ok(Self::new( - field.clone(), - child_series, - map_arr.offsets().into(), - arrow_arr.validity().cloned(), - )) + (DataType::List(daft_child_dtype), arrow2::datatypes::DataType::Map { .. }) => { + Err(DaftError::TypeError(format!( + "Arrow Map type should be converted to Daft Map type, not List. 
Attempted to create Daft ListArray with type {daft_child_dtype} from Arrow Map type.", + ))) } (d, a) => Err(DaftError::TypeError(format!( "Attempting to create Daft ListArray with type {} from arrow array with type {:?}", d, a ))), - } + }?; + + Ok(result) } } @@ -128,7 +124,7 @@ impl FromArrow for StructArray { match (&field.dtype, arrow_arr.data_type()) { (DataType::Struct(fields), arrow2::datatypes::DataType::Struct(arrow_fields)) => { if fields.len() != arrow_fields.len() { - return Err(DaftError::ValueError(format!("Attempting to create Daft StructArray with {} fields from Arrow array with {} fields: {} vs {:?}", fields.len(), arrow_fields.len(), &field.dtype, arrow_arr.data_type()))) + return Err(DaftError::ValueError(format!("Attempting to create Daft StructArray with {} fields from Arrow array with {} fields: {} vs {:?}", fields.len(), arrow_fields.len(), &field.dtype, arrow_arr.data_type()))); } let arrow_arr = arrow_arr.as_ref().as_any().downcast_ref::().unwrap(); @@ -143,7 +139,7 @@ impl FromArrow for StructArray { child_series, arrow_arr.validity().cloned(), )) - }, + } (d, a) => Err(DaftError::TypeError(format!("Attempting to create Daft StructArray with type {} from arrow array with type {:?}", d, a))) } } diff --git a/src/daft-core/src/array/ops/full.rs b/src/daft-core/src/array/ops/full.rs index ac65be6a7a..d2b90ae2e2 100644 --- a/src/daft-core/src/array/ops/full.rs +++ b/src/daft-core/src/array/ops/full.rs @@ -25,12 +25,11 @@ where { /// Creates a DataArray of size `length` that is filled with all nulls. fn full_null(name: &str, dtype: &DataType, length: usize) -> Self { - if dtype != &T::get_dtype() && !matches!(T::get_dtype(), DataType::Unknown) { - panic!( - "Cannot create DataArray from dtype: {dtype} with physical type: {}", - T::get_dtype() - ); - } + assert!( + !(dtype != &T::get_dtype() && !matches!(T::get_dtype(), DataType::Unknown)), + "Cannot create DataArray from dtype: {dtype} with physical type: {}", + T::get_dtype() + ); let field = Field::new(name, dtype.clone()); #[cfg(feature = "python")] if dtype.is_python() { diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index a9b5a14ae6..eb33064178 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -18,9 +18,12 @@ where { #[inline] pub fn get(&self, idx: usize) -> Option { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); let arrow_array = self.as_arrow(); let is_valid = arrow_array .validity() @@ -76,9 +79,12 @@ impl_array_arrow_get!(TimestampArray, i64); impl NullArray { #[inline] pub fn get(&self, idx: usize) -> Option<()> { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); None } } @@ -86,9 +92,12 @@ impl NullArray { impl ExtensionArray { #[inline] pub fn get(&self, idx: usize) -> Option> { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); let is_valid = self .data .validity() @@ -108,9 +117,12 @@ impl crate::datatypes::PythonArray { use arrow2::array::Array; use pyo3::prelude::*; - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); let 
valid = self .as_arrow() .validity() @@ -127,9 +139,12 @@ impl crate::datatypes::PythonArray { impl FixedSizeListArray { #[inline] pub fn get(&self, idx: usize) -> Option { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); let fixed_len = self.fixed_element_len(); let valid = self.is_valid(idx); if valid { @@ -147,9 +162,12 @@ impl FixedSizeListArray { impl ListArray { #[inline] pub fn get(&self, idx: usize) -> Option { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); let valid = self.is_valid(idx); if valid { let (start, end) = self.offsets().start_end(idx); diff --git a/src/daft-core/src/array/ops/groups.rs b/src/daft-core/src/array/ops/groups.rs index 9676ef3a52..6f053040c3 100644 --- a/src/daft-core/src/array/ops/groups.rs +++ b/src/daft-core/src/array/ops/groups.rs @@ -1,6 +1,6 @@ use std::{ collections::hash_map::Entry::{Occupied, Vacant}, - hash::Hash, + hash::{BuildHasherDefault, Hash}, }; use arrow2::array::Array; @@ -37,12 +37,12 @@ use crate::{ fn make_groups(iter: impl Iterator) -> DaftResult where T: Hash, - T: std::cmp::Eq, + T: Eq, { const DEFAULT_SIZE: usize = 256; let mut tbl = FnvHashMap::)>::with_capacity_and_hasher( DEFAULT_SIZE, - Default::default(), + BuildHasherDefault::default(), ); for (idx, val) in iter.enumerate() { let idx = idx as u64; @@ -56,15 +56,15 @@ where } } } - let mut s_indices = Vec::with_capacity(tbl.len()); - let mut g_indices = Vec::with_capacity(tbl.len()); + let mut sample_indices = Vec::with_capacity(tbl.len()); + let mut group_indices = Vec::with_capacity(tbl.len()); - for (s_idx, g_idx) in tbl.into_values() { - s_indices.push(s_idx); - g_indices.push(g_idx); + for (sample_index, group_index) in tbl.into_values() { + sample_indices.push(sample_index); + group_indices.push(group_index); } - Ok((s_indices, g_indices)) + Ok((sample_indices, group_indices)) } impl IntoGroups for DataArray diff --git a/src/daft-core/src/array/ops/if_else.rs b/src/daft-core/src/array/ops/if_else.rs index 8981ac2e1f..3eef4a93db 100644 --- a/src/daft-core/src/array/ops/if_else.rs +++ b/src/daft-core/src/array/ops/if_else.rs @@ -66,8 +66,6 @@ fn generic_if_else( } } } - growable.build() - // CASE 3: predicate is not broadcastable, and does not contain nulls } else { // Helper to extend the growable, taking into account broadcast semantics @@ -108,8 +106,9 @@ fn generic_if_else( if total_len != predicate.len() { extend(false, total_len, predicate.len() - total_len); } - growable.build() } + + growable.build() } impl DataArray diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 4dd8cee2a8..080fed0ad0 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -2,16 +2,24 @@ use std::{iter::repeat, sync::Arc}; use arrow2::offset::OffsetsBuffer; use common_error::DaftResult; +use indexmap::{ + map::{raw_entry_v1::RawEntryMut, RawEntryApiV1}, + IndexMap, +}; use super::as_arrow::AsArrow; use crate::{ array::{ growable::{make_growable, Growable}, - FixedSizeListArray, ListArray, + ops::arrow2::comparison::build_is_equal, + FixedSizeListArray, ListArray, StructArray, }, count_mode::CountMode, datatypes::{BooleanArray, DataType, Field, Int64Array, UInt64Array, Utf8Array}, + kernels::search_sorted::build_is_valid, + prelude::MapArray, series::{IntoSeries, 
Series}, + utils::identity_hash_set::IdentityBuildHasher, }; fn join_arrow_list_of_utf8s( @@ -25,7 +33,7 @@ fn join_arrow_list_of_utf8s( .downcast_ref::>() .unwrap() .iter() - .fold(String::from(""), |acc, str_item| { + .fold(String::new(), |acc, str_item| { acc + str_item.unwrap_or("") + delimiter_str }) // Remove trailing `delimiter_str` @@ -43,7 +51,7 @@ fn join_arrow_list_of_utf8s( // Given an i64 array that may have either 1 or `self.len()` elements, create an iterator with // `self.len()` elements. If there was originally 1 element, we repeat this element `self.len()` // times, otherwise we simply take the original array. -fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box + '_> { +fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box + 'a> { match arr.len() { 1 => Box::new(repeat(arr.get(0).unwrap()).take(len)), arr_len => { @@ -244,6 +252,134 @@ fn list_sort_helper_fixed_size( } impl ListArray { + pub fn value_counts(&self) -> DaftResult { + struct IndexRef { + index: usize, + hash: u64, + } + + impl std::hash::Hash for IndexRef { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } + } + + let original_name = self.name(); + + let hashes = self.flat_child.hash(None)?; + + let flat_child = self.flat_child.to_arrow(); + let flat_child = &*flat_child; + + let is_equal = build_is_equal( + flat_child, flat_child, + false, // this value does not matter; invalid (= nulls) are never included + true, // NaNs are equal so we do not get a bunch of {Nan: 1, Nan: 1, ...} + )?; + + let is_valid = build_is_valid(flat_child); + + let key_type = self.flat_child.data_type().clone(); + let count_type = DataType::UInt64; + + let mut include_mask = Vec::with_capacity(self.flat_child.len()); + let mut count_array = Vec::new(); + + let mut offsets = Vec::with_capacity(self.len()); + + offsets.push(0_i64); + + let mut map: IndexMap = IndexMap::default(); + for range in self.offsets().ranges() { + map.clear(); + + for index in range { + let index = index as usize; + if !is_valid(index) { + include_mask.push(false); + // skip nulls + continue; + } + + let hash = hashes.get(index).unwrap(); + + let entry = map + .raw_entry_mut_v1() + .from_hash(hash, |other| is_equal(other.index, index)); + + match entry { + RawEntryMut::Occupied(mut entry) => { + include_mask.push(false); + *entry.get_mut() += 1; + } + RawEntryMut::Vacant(vacant) => { + include_mask.push(true); + vacant.insert(IndexRef { index, hash }, 1); + } + } + } + + // IndexMap maintains insertion order, so we iterate over its values + // in the same order that elements were added. This ensures that + // the count_array values correspond to the same order in which + // the include_mask was set earlier in the loop. Each 'true' in + // include_mask represents a unique key, and its corresponding + // count is now added to count_array in the same sequence. 
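+            // Worked example (illustrative): for a row [1, 1, 2, null] the loop
+            // above pushes include_mask [true, false, true, false] and leaves
+            // the map holding {1: 2, 2: 1}, so the counts appended below line
+            // up one-to-one with the unique keys the mask keeps.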
+ for v in map.values() { + count_array.push(*v); + } + + offsets.push(count_array.len() as i64); + } + + let values = UInt64Array::from(("count", count_array)).into_series(); + let include_mask = BooleanArray::from(("boolean", include_mask.as_slice())); + + let keys = self.flat_child.filter(&include_mask)?; + + let keys = Series::try_from_field_and_arrow_array( + Field::new("key", key_type.clone()), + keys.to_arrow(), + )?; + + let values = Series::try_from_field_and_arrow_array( + Field::new("value", count_type.clone()), + values.to_arrow(), + )?; + + let struct_type = DataType::Struct(vec![ + Field::new("key", key_type.clone()), + Field::new("value", count_type.clone()), + ]); + + let struct_array = StructArray::new( + Arc::new(Field::new("entries", struct_type.clone())), + vec![keys, values], + None, + ); + + let list_type = DataType::List(Box::new(struct_type)); + + let offsets = OffsetsBuffer::try_from(offsets)?; + + let list_array = Self::new( + Arc::new(Field::new("entries", list_type)), + struct_array.into_series(), + offsets, + None, + ); + + let map_type = DataType::Map { + key: Box::new(key_type), + value: Box::new(count_type), + }; + + Ok(MapArray::new( + Field::new(original_name, map_type), + list_array, + )) + } + pub fn count(&self, mode: CountMode) -> DaftResult { let counts = match (mode, self.flat_child.validity()) { (CountMode::All, _) | (CountMode::Valid, None) => { @@ -472,6 +608,11 @@ impl ListArray { } impl FixedSizeListArray { + pub fn value_counts(&self) -> DaftResult { + let list = self.to_list(); + list.value_counts() + } + pub fn count(&self, mode: CountMode) -> DaftResult { let size = self.fixed_element_len(); let counts = match (mode, self.flat_child.validity()) { diff --git a/src/daft-core/src/array/ops/list_agg.rs b/src/daft-core/src/array/ops/list_agg.rs index 0792a17675..89bf7090a1 100644 --- a/src/daft-core/src/array/ops/list_agg.rs +++ b/src/daft-core/src/array/ops/list_agg.rs @@ -10,52 +10,72 @@ use crate::{ series::IntoSeries, }; +macro_rules! 
impl_daft_list_agg { + () => { + type Output = DaftResult; + + fn list(&self) -> Self::Output { + let child_series = self.clone().into_series(); + let offsets = + arrow2::offset::OffsetsBuffer::try_from(vec![0, child_series.len() as i64])?; + let list_field = self.field.to_list_field()?; + Ok(ListArray::new(list_field, child_series, offsets, None)) + } + + fn grouped_list(&self, groups: &GroupIndices) -> Self::Output { + let mut offsets = Vec::with_capacity(groups.len() + 1); + + offsets.push(0); + for g in groups { + offsets.push(offsets.last().unwrap() + g.len() as i64); + } + + let total_capacity = *offsets.last().unwrap(); + + let mut growable: Box = Box::new(Self::make_growable( + self.name(), + self.data_type(), + vec![self], + self.null_count() > 0, + total_capacity as usize, + )); + + for g in groups { + for idx in g { + growable.extend(0, *idx as usize, 1); + } + } + let list_field = self.field.to_list_field()?; + + Ok(ListArray::new( + list_field, + growable.build()?, + arrow2::offset::OffsetsBuffer::try_from(offsets)?, + None, + )) + } + }; +} + impl DaftListAggable for DataArray where T: DaftArrowBackedType, Self: IntoSeries, Self: GrowableArray, { - type Output = DaftResult; - fn list(&self) -> Self::Output { - let child_series = self.clone().into_series(); - let offsets = arrow2::offset::OffsetsBuffer::try_from(vec![0, child_series.len() as i64])?; - let list_field = self.field.to_list_field()?; - Ok(ListArray::new(list_field, child_series, offsets, None)) - } - - fn grouped_list(&self, groups: &GroupIndices) -> Self::Output { - let mut offsets = Vec::with_capacity(groups.len() + 1); - - offsets.push(0); - for g in groups { - offsets.push(offsets.last().unwrap() + g.len() as i64); - } + impl_daft_list_agg!(); +} - let total_capacity = *offsets.last().unwrap(); +impl DaftListAggable for ListArray { + impl_daft_list_agg!(); +} - let mut growable: Box = Box::new(Self::make_growable( - self.name(), - self.data_type(), - vec![self], - self.data.null_count() > 0, - total_capacity as usize, - )); +impl DaftListAggable for FixedSizeListArray { + impl_daft_list_agg!(); +} - for g in groups { - for idx in g { - growable.extend(0, *idx as usize, 1); - } - } - let list_field = self.field.to_list_field()?; - - Ok(ListArray::new( - list_field, - growable.build()?, - arrow2::offset::OffsetsBuffer::try_from(offsets)?, - None, - )) - } +impl DaftListAggable for StructArray { + impl_daft_list_agg!(); } #[cfg(feature = "python")] @@ -95,45 +115,3 @@ impl DaftListAggable for crate::datatypes::PythonArray { Self::new(self.field().clone().into(), Box::new(arrow_array)) } } - -impl DaftListAggable for ListArray { - type Output = DaftResult; - - fn list(&self) -> Self::Output { - // TODO(FixedSizeList) - todo!("Requires new ListArrays for implementation") - } - - fn grouped_list(&self, _groups: &GroupIndices) -> Self::Output { - // TODO(FixedSizeList) - todo!("Requires new ListArrays for implementation") - } -} - -impl DaftListAggable for FixedSizeListArray { - type Output = DaftResult; - - fn list(&self) -> Self::Output { - // TODO(FixedSizeList) - todo!("Requires new ListArrays for implementation") - } - - fn grouped_list(&self, _groups: &GroupIndices) -> Self::Output { - // TODO(FixedSizeList) - todo!("Requires new ListArrays for implementation") - } -} - -impl DaftListAggable for StructArray { - type Output = DaftResult; - - fn list(&self) -> Self::Output { - // TODO(FixedSizeList) - todo!("Requires new ListArrays for implementation") - } - - fn grouped_list(&self, _groups: &GroupIndices) -> 
Self::Output { - // TODO(FixedSizeList) - todo!("Requires new ListArrays for implementation") - } -} diff --git a/src/daft-core/src/array/ops/map.rs b/src/daft-core/src/array/ops/map.rs index 3b2f6ffd8c..c9daafe2c4 100644 --- a/src/daft-core/src/array/ops/map.rs +++ b/src/daft-core/src/array/ops/map.rs @@ -1,4 +1,5 @@ use common_error::{DaftError, DaftResult}; +use itertools::Itertools; use crate::{ array::{ops::DaftCompare, prelude::*}, @@ -6,13 +7,21 @@ use crate::{ series::Series, }; -fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult { +fn single_map_get( + structs: &Series, + key_to_get: &Series, + coerce_value: &DataType, +) -> DaftResult { let (keys, values) = { let struct_array = structs.struct_()?; (struct_array.get("key")?, struct_array.get("value")?) }; + let mask = keys.equal(key_to_get)?; let filtered = values.filter(&mask)?; + + let filtered = filtered.cast(coerce_value)?; + if filtered.is_empty() { Ok(Series::full_null("value", values.data_type(), 1)) } else if filtered.len() == 1 { @@ -24,19 +33,10 @@ fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult { impl MapArray { pub fn map_get(&self, key_to_get: &Series) -> DaftResult { - let value_type = if let DataType::Map(inner_dtype) = self.data_type() { - match *inner_dtype.clone() { - DataType::Struct(fields) if fields.len() == 2 => { - fields[1].dtype.clone() - } - _ => { - return Err(DaftError::TypeError(format!( - "Expected inner type to be a struct type with two fields: key and value, got {:?}", - inner_dtype - ))) - } - } - } else { + let DataType::Map { + value: value_type, .. + } = self.data_type() + else { return Err(DaftError::TypeError(format!( "Expected input to be a map type, got {:?}", self.data_type() @@ -44,30 +44,49 @@ impl MapArray { }; match key_to_get.len() { - 1 => { - let mut result = Vec::with_capacity(self.len()); - for series in self.physical.into_iter() { - match series { - Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?), - _ => result.push(Series::full_null("value", &value_type, 1)), - } - } - Series::concat(&result.iter().collect::>()) - } - len if len == self.len() => { - let mut result = Vec::with_capacity(len); - for (i, series) in self.physical.into_iter().enumerate() { - match (series, key_to_get.slice(i, i + 1)?) 
{ - (Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?), - _ => result.push(Series::full_null("value", &value_type, 1)), - } - } - Series::concat(&result.iter().collect::>()) - } + 1 => self.get_single_key(key_to_get, value_type), + len if len == self.len() => self.get_multiple_keys(key_to_get, value_type), _ => Err(DaftError::ValueError(format!( "Expected key to have length 1 or length equal to the map length, got {}", key_to_get.len() ))), } } + + fn get_single_key(&self, key_to_get: &Series, coerce_value: &DataType) -> DaftResult { + let result: Vec<_> = self + .physical + .iter() + .map(|series| match series { + Some(s) if !s.is_empty() => single_map_get(&s, key_to_get, coerce_value), + _ => Ok(Series::full_null("value", coerce_value, 1)), + }) + .try_collect()?; + + let result: Vec<_> = result.iter().collect(); + + Series::concat(&result) + } + + fn get_multiple_keys( + &self, + key_to_get: &Series, + coerce_value: &DataType, + ) -> DaftResult { + let result: Vec<_> = self + .physical + .iter() + .enumerate() + .map(|(i, series)| match series { + Some(s) if !s.is_empty() => { + single_map_get(&s, &key_to_get.slice(i, i + 1)?, coerce_value) + } + _ => Ok(Series::full_null("value", coerce_value, 1)), + }) + .try_collect()?; + + let result: Vec<_> = result.iter().collect(); + + Series::concat(&result) + } } diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index b4b4016bbc..d5764c4954 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -1,44 +1,27 @@ use std::sync::Arc; +use arrow2::array::PrimitiveArray; use common_error::DaftResult; -use super::{as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable}; -use crate::{array::ops::GroupIndices, count_mode::CountMode, datatypes::*}; -impl DaftMeanAggable for &DataArray { - type Output = DaftResult>; +use crate::{ + array::ops::{DaftMeanAggable, GroupIndices}, + datatypes::*, + utils::stats, +}; - fn mean(&self) -> Self::Output { - let sum_value = DaftSumAggable::sum(self)?.as_arrow().value(0); - let count_value = DaftCountAggable::count(self, CountMode::Valid)? 
- .as_arrow() - .value(0); - - let result = match count_value { - 0 => None, - count_value => Some(sum_value / count_value as f64), - }; - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result])); +impl DaftMeanAggable for DataArray { + type Output = DaftResult; - DataArray::new( - Arc::new(Field::new(self.field.name.clone(), DataType::Float64)), - arrow_array, - ) + fn mean(&self) -> Self::Output { + let stats = stats::calculate_stats(self)?; + let data = PrimitiveArray::from([stats.mean]).boxed(); + let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + Self::new(field, data) } fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output { - use arrow2::array::PrimitiveArray; - let sum_values = self.grouped_sum(groups)?; - let count_values = self.grouped_count(groups, CountMode::Valid)?; - assert_eq!(sum_values.len(), count_values.len()); - let mean_per_group = sum_values - .as_arrow() - .values_iter() - .zip(count_values.as_arrow().values_iter()) - .map(|(s, c)| match (s, c) { - (_, 0) => None, - (s, c) => Some(s / (*c as f64)), - }); - let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group)); - Ok(DataArray::from((self.field.name.as_ref(), mean_array))) + let grouped_means = stats::grouped_stats(self, groups)?.map(|(stats, _)| stats.mean); + let data = Box::new(PrimitiveArray::from_iter(grouped_means)); + Ok(Self::from((self.field.name.as_ref(), data))) } } diff --git a/src/daft-core/src/array/ops/minhash.rs b/src/daft-core/src/array/ops/minhash.rs index edf7bb6a61..0596d6951b 100644 --- a/src/daft-core/src/array/ops/minhash.rs +++ b/src/daft-core/src/array/ops/minhash.rs @@ -34,7 +34,7 @@ impl DaftMinHash for Utf8Array { let self_arrow = self.as_arrow(); let mut output: MutablePrimitiveArray = MutablePrimitiveArray::with_capacity(num_hashes * self.len()); - for maybe_s in self_arrow.iter() { + for maybe_s in self_arrow { if let Some(s) = maybe_s { let minhash_res = daft_minhash::minhash( s, diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index d3a940f376..3bcf0f0cb9 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -49,6 +49,7 @@ mod sketch_percentile; mod sort; pub(crate) mod sparse_tensor; mod sqrt; +mod stddev; mod struct_; mod sum; mod take; @@ -189,6 +190,12 @@ pub trait DaftMeanAggable { fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output; } +pub trait DaftStddevAggable { + type Output; + fn stddev(&self) -> Self::Output; + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output; +} + pub trait DaftCompareAggable { type Output; fn min(&self) -> Self::Output; diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 8d60f697c7..ad8fe9b7c7 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -13,7 +13,9 @@ use crate::{ NullArray, UInt64Array, Utf8Array, }, series::Series, - utils::display::{display_date32, display_decimal128, display_time64, display_timestamp}, + utils::display::{ + display_date32, display_decimal128, display_duration, display_time64, display_timestamp, + }, with_match_daft_types, }; @@ -34,7 +36,6 @@ macro_rules! 
impl_array_str_value { impl_array_str_value!(BooleanArray, "{}"); impl_array_str_value!(ExtensionArray, "{:?}"); -impl_array_str_value!(DurationArray, "{}"); fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult { /// influenced by pythons bytes repr @@ -105,9 +106,12 @@ impl Utf8Array { } impl NullArray { pub fn str_value(&self, idx: usize) -> DaftResult { - if idx >= self.len() { - panic!("Out of bounds: {} vs len: {}", idx, self.len()) - } + assert!( + idx < self.len(), + "Out of bounds: {} vs len: {}", + idx, + self.len() + ); Ok("None".to_string()) } } @@ -192,6 +196,21 @@ impl TimestampArray { } } +impl DurationArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let res = self.get(idx).map_or_else( + || "None".to_string(), + |val| -> String { + let DataType::Duration(time_unit) = &self.field.dtype else { + panic!("Wrong dtype for DurationArray: {}", self.field.dtype) + }; + display_duration(val, time_unit) + }, + ); + Ok(res) + } +} + impl Decimal128Array { pub fn str_value(&self, idx: usize) -> DaftResult { let res = self.get(idx).map_or_else( diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index ba2d791101..19bf41574e 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -45,7 +45,7 @@ pub fn build_multi_array_bicompare( } let combined_comparator = Box::new(move |a_idx: usize, b_idx: usize| -> std::cmp::Ordering { - for comparator in cmp_list.iter() { + for comparator in &cmp_list { match comparator(a_idx, b_idx) { std::cmp::Ordering::Equal => continue, other => return other, diff --git a/src/daft-core/src/array/ops/sparse_tensor.rs b/src/daft-core/src/array/ops/sparse_tensor.rs index 696a5996b8..010a6740a3 100644 --- a/src/daft-core/src/array/ops/sparse_tensor.rs +++ b/src/daft-core/src/array/ops/sparse_tensor.rs @@ -63,6 +63,7 @@ mod tests { Some(validity.clone()), ) .into_series(); + let indices_array = ListArray::new( Field::new("indices", DataType::List(Box::new(DataType::UInt64))), UInt64Array::from(( @@ -90,11 +91,12 @@ mod tests { Some(validity.clone()), ) .into_series(); + let dtype = DataType::SparseTensor(Box::new(DataType::Int64)); let struct_array = StructArray::new( Field::new("tensor", dtype.to_physical()), vec![values_array, indices_array, shapes_array], - Some(validity.clone()), + Some(validity), ); let sparse_tensor_array = SparseTensorArray::new(Field::new(struct_array.name(), dtype.clone()), struct_array); @@ -103,9 +105,12 @@ mod tests { let fixed_shape_sparse_tensor_array = sparse_tensor_array.cast(&fixed_shape_sparse_tensor_dtype)?; let roundtrip_tensor = fixed_shape_sparse_tensor_array.cast(&dtype)?; - assert!(roundtrip_tensor - .to_arrow() - .eq(&sparse_tensor_array.to_arrow())); + + let round_trip_tensor_arrow = roundtrip_tensor.to_arrow(); + let sparse_tensor_array_arrow = sparse_tensor_array.to_arrow(); + + assert_eq!(round_trip_tensor_arrow, sparse_tensor_array_arrow); + Ok(()) } } diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs new file mode 100644 index 0000000000..c412922937 --- /dev/null +++ b/src/daft-core/src/array/ops/stddev.rs @@ -0,0 +1,34 @@ +use arrow2::array::PrimitiveArray; +use common_error::DaftResult; + +use crate::{ + array::{ + ops::{DaftStddevAggable, GroupIndices}, + DataArray, + }, + datatypes::Float64Type, + utils::stats, +}; + +impl DaftStddevAggable for DataArray { + type Output = DaftResult; + + fn stddev(&self) -> Self::Output { + let stats = stats::calculate_stats(self)?; + let 
values = self.into_iter().flatten().copied(); + let stddev = stats::calculate_stddev(stats, values); + let field = self.field.clone(); + let data = PrimitiveArray::::from([stddev]).boxed(); + Self::new(field, data) + } + + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output { + let grouped_stddevs_iter = stats::grouped_stats(self, groups)?.map(|(stats, group)| { + let values = group.iter().filter_map(|&index| self.get(index as _)); + stats::calculate_stddev(stats, values) + }); + let field = self.field.clone(); + let data = PrimitiveArray::::from_iter(grouped_stddevs_iter).boxed(); + Self::new(field, data) + } +} diff --git a/src/daft-core/src/array/ops/struct_.rs b/src/daft-core/src/array/ops/struct_.rs index 64fbf74cc8..e077c577d8 100644 --- a/src/daft-core/src/array/ops/struct_.rs +++ b/src/daft-core/src/array/ops/struct_.rs @@ -52,7 +52,7 @@ mod tests { "foo", DataType::Struct(vec![Field::new("bar", DataType::Int64)]), ), - vec![child.clone().into_series()], + vec![child.into_series()], None, ); @@ -68,7 +68,7 @@ mod tests { assert_eq!(old_child.get(2), None); assert_eq!(old_child.get(3), None); - parent = parent.with_validity(Some(parent_validity.clone()))?; + parent = parent.with_validity(Some(parent_validity))?; let new_child = parent.get("bar")?.i64()?.clone(); let new_child_validity = new_child diff --git a/src/daft-core/src/array/ops/tensor.rs b/src/daft-core/src/array/ops/tensor.rs index c1cd0f13ec..17c16f1793 100644 --- a/src/daft-core/src/array/ops/tensor.rs +++ b/src/daft-core/src/array/ops/tensor.rs @@ -68,7 +68,7 @@ mod tests { let struct_array = StructArray::new( Field::new("tensor", dtype.to_physical()), vec![list_array, shapes_array], - Some(validity.clone()), + Some(validity), ); let tensor_array = TensorArray::new(Field::new(struct_array.name(), dtype.clone()), struct_array); @@ -85,7 +85,7 @@ mod tests { let validity = arrow2::bitmap::Bitmap::from(raw_validity.as_slice()); let field = Field::new("foo", DataType::FixedSizeList(Box::new(DataType::Int64), 3)); let flat_child = Int64Array::from(("foo", (0..9).collect::>())); - let arr = FixedSizeListArray::new(field, flat_child.into_series(), Some(validity.clone())); + let arr = FixedSizeListArray::new(field, flat_child.into_series(), Some(validity)); let dtype = DataType::FixedShapeTensor(Box::new(DataType::Int64), vec![3]); let tensor_array = FixedShapeTensorArray::new(Field::new("data", dtype.clone()), arr); let sparse_tensor_dtype = diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index ebac895e20..f67ee6977b 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -275,9 +275,9 @@ fn substring(s: &str, start: usize, len: Option) -> Option<&str> { Some(len) => { if len == 0 { return None; - } else { - len } + + len } None => { return Some(&s[start_pos..]); @@ -342,7 +342,7 @@ pub enum PadPlacement { Right, } -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, Default)] pub struct Utf8NormalizeOptions { pub remove_punct: bool, pub lowercase: bool, @@ -438,7 +438,7 @@ impl Utf8Array { &mut splits, &mut offsets, &mut validity, - )? + )?; } (true, _) => { let regex_iter = pattern @@ -451,7 +451,7 @@ impl Utf8Array { &mut splits, &mut offsets, &mut validity, - )? + )?; } (false, _) => { let pattern_iter = create_broadcasted_str_iter(pattern, expected_size); @@ -461,7 +461,7 @@ impl Utf8Array { &mut splits, &mut offsets, &mut validity, - )? 
+ )?; } } // Shrink splits capacity to current length, since we will have overallocated if any of the patterns actually occurred in the strings. @@ -1389,7 +1389,7 @@ impl Utf8Array { // ensure this match is a whole word (or set of words) // don't want to filter out things like "brass" let prev_char = s.get(m.start() - 1..m.start()); - let next_char = s.get(m.end()..m.end() + 1); + let next_char = s.get(m.end()..=m.end()); !(prev_char.is_some_and(|s| s.chars().next().unwrap().is_alphabetic()) || next_char .is_some_and(|s| s.chars().next().unwrap().is_alphabetic())) diff --git a/src/daft-core/src/array/pseudo_arrow/compute.rs b/src/daft-core/src/array/pseudo_arrow/compute.rs index 65b11a69c1..d49c2d6000 100644 --- a/src/daft-core/src/array/pseudo_arrow/compute.rs +++ b/src/daft-core/src/array/pseudo_arrow/compute.rs @@ -9,7 +9,7 @@ impl PseudoArrowArray { // Concatenate the values and the validity separately. let mut concatenated_values: Vec = Vec::new(); - for array in arrays.iter() { + for array in &arrays { concatenated_values.extend_from_slice(array.values()); } diff --git a/src/daft-core/src/array/serdes.rs b/src/daft-core/src/array/serdes.rs index cc908c0dd6..0976f53a0a 100644 --- a/src/daft-core/src/array/serdes.rs +++ b/src/daft-core/src/array/serdes.rs @@ -130,7 +130,11 @@ impl serde::Serialize for ExtensionArray { let mut s = serializer.serialize_map(Some(2))?; s.serialize_entry("field", self.field())?; let values = if let DataType::Extension(_, inner, _) = self.data_type() { - Series::try_from(("physical", self.data.to_type(inner.to_arrow().unwrap()))).unwrap() + Series::try_from(( + "physical", + self.data.convert_logical_type(inner.to_arrow().unwrap()), + )) + .unwrap() } else { panic!("Expected Extension Type!") }; diff --git a/src/daft-core/src/array/struct_array.rs b/src/daft-core/src/array/struct_array.rs index 996680ede5..d64350736b 100644 --- a/src/daft-core/src/array/struct_array.rs +++ b/src/daft-core/src/array/struct_array.rs @@ -11,6 +11,8 @@ use crate::{ #[derive(Clone, Debug)] pub struct StructArray { pub field: Arc, + + /// Column representations pub children: Vec, validity: Option, len: usize, @@ -31,16 +33,15 @@ impl StructArray { let field: Arc = field.into(); match &field.as_ref().dtype { DataType::Struct(fields) => { - if fields.len() != children.len() { - panic!("StructArray::new received {} children arrays but expected {} for specified dtype: {}", children.len(), fields.len(), &field.as_ref().dtype) - } + assert!(fields.len() == children.len(), "StructArray::new received {} children arrays but expected {} for specified dtype: {}", children.len(), fields.len(), &field.as_ref().dtype); for (dtype_field, series) in fields.iter().zip(children.iter()) { - if &dtype_field.dtype != series.data_type() { - panic!("StructArray::new received an array with dtype: {} but expected child field: {}", series.data_type(), dtype_field) - } - if dtype_field.name != series.name() { - panic!("StructArray::new received a series with name: {} but expected name: {}", series.name(), &dtype_field.name) - } + assert!(!(&dtype_field.dtype != series.data_type()), "StructArray::new received an array with dtype: {} but expected child field: {}", series.data_type(), dtype_field); + assert!( + dtype_field.name == series.name(), + "StructArray::new received a series with name: {} but expected name: {}", + series.name(), + &dtype_field.name + ); } let len = if !children.is_empty() { @@ -49,10 +50,8 @@ impl StructArray { 0 }; - for s in children.iter() { - if s.len() != len { - 
panic!("StructArray::new expects all children to have the same length, but received: {} vs {}", s.len(), len) - } + for s in &children { + assert!(s.len() == len, "StructArray::new expects all children to have the same length, but received: {} vs {}", s.len(), len); } if let Some(some_validity) = &validity && some_validity.len() != len @@ -82,6 +81,13 @@ impl StructArray { self.validity.as_ref() } + pub fn null_count(&self) -> usize { + match self.validity() { + None => 0, + Some(validity) => validity.unset_bits(), + } + } + pub fn concat(arrays: &[&Self]) -> DaftResult { if arrays.is_empty() { return Err(DaftError::ValueError( diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs index a6420b039b..c1f04fecbe 100644 --- a/src/daft-core/src/datatypes/agg_ops.rs +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -23,7 +23,7 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { } /// Get the data type that the mean of a column of the given data type should be casted to. -pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { +pub fn try_mean_stddev_aggregation_supertype(dtype: &DataType) -> DaftResult { if dtype.is_numeric() { Ok(DataType::Float64) } else { diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index 9c05eb0b02..08f83d5198 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -120,7 +120,7 @@ impl<'a> Add for InferDataType<'a> { type Output = DaftResult; fn add(self, other: Self) -> Self::Output { - try_numeric_supertype(self.0, other.0).or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {InferDataType::from(l) + InferDataType::from(r)})).or( + try_numeric_supertype(self.0, other.0).or_else(|_| try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {InferDataType::from(l) + InferDataType::from(r)})).or( match (self.0, other.0) { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), @@ -138,27 +138,27 @@ impl<'a> Add for InferDataType<'a> { (du_self @ &DataType::Duration(..), du_other @ &DataType::Duration(..)) => Err(DaftError::TypeError( format!("Cannot add due to differing precision: {}, {}. Please explicitly cast to the precision you wish to add in.", du_self, du_other) )), - (DataType::Null, other) | (other, DataType::Null) => { + (dtype @ DataType::Null, other) | (other, dtype @ DataType::Null) => { match other { // Condition is for backwards compatibility. TODO: remove DataType::Binary | DataType::FixedSizeBinary(..) | DataType::Date => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), other if other.is_physical() => Ok(other.clone()), _ => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), } } - (DataType::Utf8, other) | (other, DataType::Utf8) => { + (dtype @ DataType::Utf8, other) | (other, dtype @ DataType::Utf8) => { match other { // DataType::Date condition is for backwards compatibility. TODO: remove DataType::Binary | DataType::FixedSizeBinary(..) 
| DataType::Date => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), other if other.is_physical() => Ok(DataType::Utf8), _ => Err(DaftError::TypeError( - format!("Cannot add types: {}, {}", self, other) + format!("Cannot add types: {}, {}", dtype, other) )), } } @@ -176,7 +176,7 @@ impl<'a> Sub for InferDataType<'a> { type Output = DaftResult; fn sub(self, other: Self) -> Self::Output { - try_numeric_supertype(self.0, other.0).or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {InferDataType::from(l) - InferDataType::from(r)})).or( + try_numeric_supertype(self.0, other.0).or_else(|_| try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {InferDataType::from(l) - InferDataType::from(r)})).or( match (self.0, other.0) { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), @@ -219,9 +219,11 @@ impl<'a> Div for InferDataType<'a> { self, other ))), } - .or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { - InferDataType::from(l) / InferDataType::from(r) - })) + .or_else(|_| { + try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { + InferDataType::from(l) / InferDataType::from(r) + }) + }) } } @@ -230,9 +232,11 @@ impl<'a> Mul for InferDataType<'a> { fn mul(self, other: Self) -> Self::Output { try_numeric_supertype(self.0, other.0) - .or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { - InferDataType::from(l) * InferDataType::from(r) - })) + .or_else(|_| { + try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { + InferDataType::from(l) * InferDataType::from(r) + }) + }) .or(match (self.0, other.0) { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), @@ -249,9 +253,11 @@ impl<'a> Rem for InferDataType<'a> { fn rem(self, other: Self) -> Self::Output { try_numeric_supertype(self.0, other.0) - .or(try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { - InferDataType::from(l) % InferDataType::from(r) - })) + .or_else(|_| { + try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { + InferDataType::from(l) % InferDataType::from(r) + }) + }) .or(match (self.0, other.0) { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), @@ -394,7 +400,7 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult } inner(l, r) - .or(inner(r, l)) + .or_else(|| inner(r, l)) .ok_or(DaftError::TypeError(format!( "Invalid arguments to numeric supertype: {}, {}", l, r diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index 86d84535e1..9704b3b76f 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -44,6 +44,7 @@ impl LogicalArrayImpl { &field.dtype.to_physical(), physical.data_type() ); + Self { physical, field, diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index b8b8e1660f..c275bb4a2d 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -8,43 +8,43 @@ macro_rules! with_match_daft_types {( use $crate::datatypes::*; match $key_type { - Null => __with_ty__! { NullType }, + // Float16 => unimplemented!("Array for Float16 DataType not implemented"), + Binary => __with_ty__! { BinaryType }, Boolean => __with_ty__! { BooleanType }, - Int8 => __with_ty__! { Int8Type }, + Date => __with_ty__! { DateType }, + Decimal128(..) => __with_ty__! 
{ Decimal128Type }, + Duration(_) => __with_ty__! { DurationType }, + Embedding(..) => __with_ty__! { EmbeddingType }, + Extension(_, _, _) => __with_ty__! { ExtensionType }, + FixedShapeImage(..) => __with_ty__! { FixedShapeImageType }, + FixedShapeSparseTensor(..) => __with_ty__! { FixedShapeSparseTensorType }, + FixedShapeTensor(..) => __with_ty__! { FixedShapeTensorType }, + FixedSizeBinary(_) => __with_ty__! { FixedSizeBinaryType }, + FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, + Float32 => __with_ty__! { Float32Type }, + Float64 => __with_ty__! { Float64Type }, + Image(..) => __with_ty__! { ImageType }, + Int128 => __with_ty__! { Int128Type }, Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, Int64 => __with_ty__! { Int64Type }, - Int128 => __with_ty__! { Int128Type }, - UInt8 => __with_ty__! { UInt8Type }, + Int8 => __with_ty__! { Int8Type }, + List(_) => __with_ty__! { ListType }, + Map{..} => __with_ty__! { MapType }, + Null => __with_ty__! { NullType }, + SparseTensor(..) => __with_ty__! { SparseTensorType }, + Struct(_) => __with_ty__! { StructType }, + Tensor(..) => __with_ty__! { TensorType }, + Time(_) => __with_ty__! { TimeType }, + Timestamp(_, _) => __with_ty__! { TimestampType }, UInt16 => __with_ty__! { UInt16Type }, UInt32 => __with_ty__! { UInt32Type }, UInt64 => __with_ty__! { UInt64Type }, - Float32 => __with_ty__! { Float32Type }, - Float64 => __with_ty__! { Float64Type }, - Timestamp(_, _) => __with_ty__! { TimestampType }, - Date => __with_ty__! { DateType }, - Time(_) => __with_ty__! { TimeType }, - Duration(_) => __with_ty__! { DurationType }, - Binary => __with_ty__! { BinaryType }, - FixedSizeBinary(_) => __with_ty__! { FixedSizeBinaryType }, + UInt8 => __with_ty__! { UInt8Type }, + Unknown => unimplemented!("Array for Unknown DataType not implemented"), Utf8 => __with_ty__! { Utf8Type }, - FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, - List(_) => __with_ty__! { ListType }, - Struct(_) => __with_ty__! { StructType }, - Map(_) => __with_ty__! { MapType }, - Extension(_, _, _) => __with_ty__! { ExtensionType }, #[cfg(feature = "python")] Python => __with_ty__! { PythonType }, - Embedding(..) => __with_ty__! { EmbeddingType }, - Image(..) => __with_ty__! { ImageType }, - FixedShapeImage(..) => __with_ty__! { FixedShapeImageType }, - Tensor(..) => __with_ty__! { TensorType }, - FixedShapeTensor(..) => __with_ty__! { FixedShapeTensorType }, - SparseTensor(..) => __with_ty__! { SparseTensorType }, - FixedShapeSparseTensor(..) => __with_ty__! { FixedShapeSparseTensorType }, - Decimal128(..) => __with_ty__! { Decimal128Type }, - // Float16 => unimplemented!("Array for Float16 DataType not implemented"), - Unknown => unimplemented!("Array for Unknown DataType not implemented"), // NOTE: We should not implement a default for match here, because this is meant to be // an exhaustive match across **all** Daft types. 
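The NOTE that closes this macro is the important invariant in this reordering: `with_match_daft_types!` deliberately has no `_` fallback arm, so adding a new `DataType` variant is a compile error until a matching arm is written. As a rough illustration of the pattern, here is a minimal, self-contained sketch; `MiniType` and `dispatch!` are invented stand-ins, not Daft's actual macro internals (the real macro routes the body through the nested `__with_ty__!` macro seen above):

```rust
// A toy version of the exhaustive type-dispatch macro pattern.
enum MiniType {
    Int64,
    Float64,
}

macro_rules! dispatch {
    ($key:expr, |$T:ident| $body:block) => {
        match $key {
            // One arm per variant and no `_` fallback: adding a variant to
            // `MiniType` without adding an arm here fails to compile, which
            // is exactly what the exhaustive match buys us.
            MiniType::Int64 => {
                type $T = i64;
                $body
            }
            MiniType::Float64 => {
                type $T = f64;
                $body
            }
        }
    };
}

// The caller writes one generic body over `T`, and the macro instantiates it
// once per concrete type.
fn default_repr(t: MiniType) -> String {
    dispatch!(t, |T| { format!("{:?}", T::default()) })
}

fn main() {
    assert_eq!(default_repr(MiniType::Int64), "0");
    assert_eq!(default_repr(MiniType::Float64), "0.0");
}
```

Call sites in Daft follow the same shape: `Series::concat` later in this diff writes a single body over `$T` and lets the macro stamp it out per `<$T as DaftDataType>::ArrayType`.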
diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 174098ada9..01a6b6ca6e 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -6,7 +6,7 @@ pub use infer_datatype::InferDataType; pub mod prelude; use std::ops::{Add, Div, Mul, Rem, Sub}; -pub use agg_ops::{try_mean_supertype, try_sum_supertype}; +pub use agg_ops::{try_mean_stddev_aggregation_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, diff --git a/src/daft-core/src/kernels/search_sorted.rs b/src/daft-core/src/kernels/search_sorted.rs index f8b0a0f946..cff5785b08 100644 --- a/src/daft-core/src/kernels/search_sorted.rs +++ b/src/daft-core/src/kernels/search_sorted.rs @@ -27,7 +27,7 @@ where { let mut last_key = keys.iter().next().unwrap_or(None); let less = |l: &T, r: &T| l < r || (r != r && l == l); - for key_val in keys.iter() { + for key_val in keys { let is_last_key_lt = match (last_key, key_val) { (None, None) => false, (None, Some(_)) => input_reversed, @@ -90,7 +90,7 @@ fn search_sorted_utf_array( let mut results: Vec = Vec::with_capacity(array_size); let mut last_key = keys.iter().next().unwrap_or(None); - for key_val in keys.iter() { + for key_val in keys { let is_last_key_lt = match (last_key, key_val) { (None, None) => false, (None, Some(_)) => input_reversed, @@ -228,7 +228,7 @@ fn search_sorted_binary_array( let mut results: Vec = Vec::with_capacity(array_size); let mut last_key = keys.iter().next().unwrap_or(None); - for key_val in keys.iter() { + for key_val in keys { let is_last_key_lt = match (last_key, key_val) { (None, None) => false, (None, Some(_)) => input_reversed, @@ -291,7 +291,7 @@ fn search_sorted_fixed_size_binary_array( let mut results: Vec = Vec::with_capacity(array_size); let mut last_key = keys.iter().next().unwrap_or(None); - for key_val in keys.iter() { + for key_val in keys { let is_last_key_lt = match (last_key, key_val) { (None, None) => false, (None, Some(_)) => input_reversed, @@ -536,7 +536,7 @@ pub fn search_sorted_multi_array( } let combined_comparator = |a_idx: usize, b_idx: usize| -> Ordering { - for comparator in cmp_list.iter() { + for comparator in &cmp_list { match comparator(a_idx, b_idx) { Ordering::Equal => continue, other => return other, diff --git a/src/daft-core/src/lib.rs b/src/daft-core/src/lib.rs index 322a0db3ec..5892f75ffb 100644 --- a/src/daft-core/src/lib.rs +++ b/src/daft-core/src/lib.rs @@ -2,6 +2,7 @@ #![feature(int_roundings)] #![feature(iterator_try_reduce)] #![feature(if_let_guard)] +#![feature(hash_raw_entry)] pub mod array; pub mod count_mode; diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index f57bf3f829..87304b12d1 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -740,7 +740,7 @@ fn infer_daft_dtype_for_sequence( .getattr(pyo3::intern!(py, "from_numpy_dtype"))? }; let mut dtype: Option = None; - for obj in vec_pyobj.iter() { + for obj in vec_pyobj { let obj = obj.bind(py); if let Ok(pil_image_type) = &py_pil_image_type && obj.is_instance(pil_image_type)? 
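Several hunks above (`sort.rs` and the `search_sorted` kernels) change comparator loops from `cmp_list.iter()` to `&cmp_list`, but the underlying composition is unchanged: each column contributes a comparator over row indices, and rows compare lexicographically by the first non-`Equal` result. A small self-contained sketch of that idea, using plain slices rather than Series (all names here are illustrative):

```rust
use std::cmp::Ordering;

// A row comparator over some column, boxed so columns of different types can
// live in one list.
type RowCmp<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

fn column_cmp<T: PartialOrd>(col: &[T]) -> RowCmp<'_> {
    // Incomparable values (e.g. NaN) are treated as ties here for simplicity.
    Box::new(move |a, b| col[a].partial_cmp(&col[b]).unwrap_or(Ordering::Equal))
}

// Lexicographic chaining: the first column that is not a tie decides,
// mirroring the `for comparator in &cmp_list` loops above.
fn combined_cmp<'a>(cmp_list: &'a [RowCmp<'a>]) -> impl Fn(usize, usize) -> Ordering + 'a {
    move |a, b| {
        for comparator in cmp_list {
            match comparator(a, b) {
                Ordering::Equal => continue, // tie on this column, try the next
                other => return other,
            }
        }
        Ordering::Equal // equal on every column
    }
}

fn main() {
    let keys = [1i64, 1, 2];
    let vals = [0.5f64, 0.1, 0.9];
    let cmps = vec![column_cmp(&keys), column_cmp(&vals)];
    let cmp = combined_cmp(&cmps);
    assert_eq!(cmp(0, 1), Ordering::Greater); // keys tie at 1, so 0.5 vs 0.1 decides
    assert_eq!(cmp(0, 2), Ordering::Less); // 1 < 2 on the first column, done
}
```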
diff --git a/src/daft-core/src/series/from.rs b/src/daft-core/src/series/from.rs
index 99776edf64..8c56aaa625 100644
--- a/src/daft-core/src/series/from.rs
+++ b/src/daft-core/src/series/from.rs
@@ -1,6 +1,8 @@
 use std::sync::Arc;
 
+use arrow2::datatypes::ArrowDataType;
 use common_error::{DaftError, DaftResult};
+use daft_schema::{dtype::DaftDataType, field::DaftField};
 
 use super::Series;
 use crate::{
@@ -12,9 +14,10 @@ use crate::{
 
 impl Series {
     pub fn try_from_field_and_arrow_array(
-        field: Arc<Field>,
+        field: impl Into<Arc<Field>>,
         array: Box<dyn arrow2::array::Array>,
     ) -> DaftResult<Self> {
+        let field = field.into();
         // TODO(Nested): Refactor this out with nested logical types in StructArray and ListArray
         // Corner-case nested logical types that have not yet been migrated to new Array formats
         // to hold only casted physical arrow arrays.
@@ -46,11 +49,90 @@ impl Series {
 impl TryFrom<(&str, Box<dyn arrow2::array::Array>)> for Series {
     type Error = DaftError;
 
-    fn try_from(item: (&str, Box<dyn arrow2::array::Array>)) -> DaftResult<Self> {
-        let (name, array) = item;
-        let source_arrow_type = array.data_type();
-        let dtype: DataType = source_arrow_type.into();
-        let field = Arc::new(Field::new(name, dtype.clone()));
+    fn try_from((name, array): (&str, Box<dyn arrow2::array::Array>)) -> DaftResult<Self> {
+        let source_arrow_type: &ArrowDataType = array.data_type();
+        let dtype = DaftDataType::from(source_arrow_type);
+
+        let field = Arc::new(Field::new(name, dtype));
         Self::try_from_field_and_arrow_array(field, array)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::LazyLock;
+
+    use arrow2::{
+        array::Array,
+        datatypes::{ArrowDataType, ArrowField},
+    };
+    use common_error::DaftResult;
+    use daft_schema::dtype::DataType;
+
+    static ARROW_DATA_TYPE: LazyLock<ArrowDataType> = LazyLock::new(|| {
+        ArrowDataType::Map(
+            Box::new(ArrowField::new(
+                "entries",
+                ArrowDataType::Struct(vec![
+                    ArrowField::new("key", ArrowDataType::LargeUtf8, false),
+                    ArrowField::new("value", ArrowDataType::Date32, true),
+                ]),
+                false,
+            )),
+            false,
+        )
+    });
+
+    #[test]
+    fn test_map_type_conversion() {
+        let arrow_data_type = ARROW_DATA_TYPE.clone();
+        let dtype = DataType::from(&arrow_data_type);
+        assert_eq!(
+            dtype,
+            DataType::Map {
+                key: Box::new(DataType::Utf8),
+                value: Box::new(DataType::Date),
+            },
+        );
+    }
+
+    #[test]
+    fn test_map_array_conversion() -> DaftResult<()> {
+        use arrow2::array::MapArray;
+
+        use super::*;
+
+        let arrow_array = MapArray::new(
+            ARROW_DATA_TYPE.clone(),
+            vec![0, 1].try_into().unwrap(),
+            Box::new(arrow2::array::StructArray::new(
+                ArrowDataType::Struct(vec![
+                    ArrowField::new("key", ArrowDataType::LargeUtf8, false),
+                    ArrowField::new("value", ArrowDataType::Date32, true),
+                ]),
+                vec![
+                    Box::new(arrow2::array::Utf8Array::<i64>::from_slice(["key1"])),
+                    arrow2::array::Int32Array::from_slice([1])
+                        .convert_logical_type(ArrowDataType::Date32),
+                ],
+                None,
+            )),
+            None,
+        );
+
+        let series = Series::try_from((
+            "test_map",
+            Box::new(arrow_array) as Box<dyn Array>,
+        ))?;
+
+        assert_eq!(
+            series.data_type(),
+            &DataType::Map {
+                key: Box::new(DataType::Utf8),
+                value: Box::new(DataType::Date),
+            }
+        );
+
+        Ok(())
+    }
+}
diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs
index 128b1bd344..59a8f66d05 100644
--- a/src/daft-core/src/series/mod.rs
+++ b/src/daft-core/src/series/mod.rs
@@ -38,6 +38,12 @@ impl PartialEq for Series {
     }
 }
 
 impl Series {
+    /// Exports this Series into an Arrow array that is corrected for the Arrow type system.
+    /// For example, Daft's TimestampArray is a logical type backed by an Int64Array physical array,
+    /// so calling `.as_arrow()` or `.physical` on a TimestampArray yields an Int64Array of raw time units.
+    /// However, when exporting a Timestamp array to another Arrow consumer (arrow2 kernels, Python, DuckDB, and so on),
+    /// we should convert it back to the canonical Arrow dtype of Timestamp rather than Int64.
+    /// To get the internal physical type without conversion, see `as_arrow()`.
     pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
         self.inner.to_arrow()
     }
@@ -81,6 +87,7 @@ impl Series {
         self.inner.name()
     }
 
+    #[must_use]
     pub fn rename<S: AsRef<str>>(&self, name: S) -> Self {
         self.inner.rename(name.as_ref())
     }
@@ -117,6 +124,13 @@ impl Series {
         self.inner.validity()
     }
 
+    pub fn is_valid(&self, idx: usize) -> bool {
+        let Some(validity) = self.validity() else {
+            return true;
+        };
+        validity.get_bit(idx)
+    }
+
     /// Attempts to downcast the Series to a primitive slice
     /// This will return an error if the Series is not of the physical type `T`
     /// # Example
diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs
index 541fe5c556..b3bfee765c 100644
--- a/src/daft-core/src/series/ops/agg.rs
+++ b/src/daft-core/src/series/ops/agg.rs
@@ -4,7 +4,10 @@ use logical::Decimal128Array;
 
 use crate::{
     array::{
-        ops::{DaftHllMergeAggable, GroupIndices},
+        ops::{
+            DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable,
+            DaftSumAggable, GroupIndices,
+        },
         ListArray,
     },
     count_mode::CountMode,
@@ -26,12 +29,10 @@ impl Series {
     }
 
     pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
-        use crate::{array::ops::DaftSumAggable, datatypes::DataType::*};
-
         match self.data_type() {
             // intX -> int64 (in line with numpy)
-            Int8 | Int16 | Int32 | Int64 => {
-                let casted = self.cast(&Int64)?;
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+                let casted = self.cast(&DataType::Int64)?;
                 match groups {
                     Some(groups) => {
                         Ok(DaftSumAggable::grouped_sum(&casted.i64()?, groups)?.into_series())
@@ -40,8 +41,8 @@
                 }
             }
             // uintX -> uint64 (in line with numpy)
-            UInt8 | UInt16 | UInt32 | UInt64 => {
-                let casted = self.cast(&UInt64)?;
+            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
+                let casted = self.cast(&DataType::UInt64)?;
                 match groups {
                     Some(groups) => {
                         Ok(DaftSumAggable::grouped_sum(&casted.u64()?, groups)?.into_series())
@@ -50,7 +51,7 @@
                 }
             }
             // floatX -> floatX (in line with numpy)
-            Float32 => match groups {
+            DataType::Float32 => match groups {
                 Some(groups) => Ok(DaftSumAggable::grouped_sum(
                     &self.downcast::<Float32Array>()?,
                     groups,
@@ -58,7 +59,7 @@
                 .into_series()),
                 None => Ok(DaftSumAggable::sum(&self.downcast::<Float32Array>()?)?.into_series()),
             },
-            Float64 => match groups {
+            DataType::Float64 => match groups {
                 Some(groups) => Ok(DaftSumAggable::grouped_sum(
                     &self.downcast::<Float64Array>()?,
                     groups,
@@ -66,7 +67,7 @@
                 .into_series()),
                 None => Ok(DaftSumAggable::sum(&self.downcast::<Float64Array>()?)?.into_series()),
             },
-            Decimal128(_, _) => match groups {
+            DataType::Decimal128(_, _) => match groups {
                 Some(groups) => Ok(Decimal128Array::new(
                     Field {
                         dtype: try_sum_supertype(self.data_type())?,
@@ -95,12 +96,10 @@
     }
 
     pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
-        use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*};
-
         // Upcast all numeric types to float64 and compute approx_sketch.
match self.data_type() { dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; + let casted = self.cast(&DataType::Float64)?; match groups { Some(groups) => Ok(DaftApproxSketchAggable::grouped_approx_sketch( &casted.f64()?, @@ -149,24 +148,25 @@ impl Series { } pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*}; - // Upcast all numeric types to float64 and use f64 mean kernel. - match self.data_type() { - dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; - match groups { - Some(groups) => { - Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series()) - } - None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()), - } - } - other => Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for type {}", - other - ))), - } + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))? + .into_series(); + Ok(series) + } + + pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult { + // Upcast all numeric types to float64 and use f64 stddev kernel. + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.stddev(), |groups| casted.grouped_stddev(groups))? + .into_series(); + Ok(series) } pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { @@ -191,7 +191,7 @@ impl Series { ))) } else { Box::new(PrimitiveArray::from_trusted_len_iter( - groups.iter().map(|g| g.first().cloned()), + groups.iter().map(|g| g.first().copied()), )) } } diff --git a/src/daft-core/src/series/ops/between.rs b/src/daft-core/src/series/ops/between.rs index 6c53cbb86c..e1a3a680db 100644 --- a/src/daft-core/src/series/ops/between.rs +++ b/src/daft-core/src/series/ops/between.rs @@ -22,7 +22,7 @@ impl Series { } else { (self.clone(), lower.clone(), upper.clone()) }; - if let DataType::Boolean = output_type { + if output_type == DataType::Boolean { match comp_type { #[cfg(feature = "python")] DataType::Python => Ok(py_between_op_utilfn(self, upper, lower)? diff --git a/src/daft-core/src/series/ops/concat.rs b/src/daft-core/src/series/ops/concat.rs index 9103255faf..e6b202b62f 100644 --- a/src/daft-core/src/series/ops/concat.rs +++ b/src/daft-core/src/series/ops/concat.rs @@ -7,30 +7,29 @@ use crate::{ impl Series { pub fn concat(series: &[&Self]) -> DaftResult { - if series.is_empty() { - return Err(DaftError::ValueError( - "Need at least 1 series to perform concat".to_string(), - )); - } + let all_types: Vec<_> = series.iter().map(|s| s.data_type().clone()).collect(); - if series.len() == 1 { - return Ok((*series.first().unwrap()).clone()); - } + match series { + [] => Err(DaftError::ValueError( + "Need at least 1 series to perform concat".to_string(), + )), + [single_series] => Ok((*single_series).clone()), + [first, rest @ ..] => { + let first_dtype = first.data_type(); + for s in rest { + if first_dtype != s.data_type() { + return Err(DaftError::TypeError(format!( + "Series concat requires all data types to match. Found mismatched types. 
All types: {:?}",
+                            all_types
+                        )));
+                    }
+                }
 
-        let first_dtype = series.first().unwrap().data_type();
-        for s in series.iter().skip(1) {
-            if first_dtype != s.data_type() {
-                return Err(DaftError::TypeError(format!(
-                    "Series concat requires all data types to match, {} vs {}",
-                    first_dtype,
-                    s.data_type()
-                )));
+                with_match_daft_types!(first_dtype, |$T| {
+                    let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::<DaftResult<Vec<_>>>()?;
+                    Ok(<$T as DaftDataType>::ArrayType::concat(downcasted.as_slice())?.into_series())
+                })
             }
         }
-
-        with_match_daft_types!(first_dtype, |$T| {
-            let downcasted = series.into_iter().map(|s| s.downcast::<<$T as DaftDataType>::ArrayType>()).collect::<DaftResult<Vec<_>>>()?;
-            Ok(<$T as DaftDataType>::ArrayType::concat(downcasted.as_slice())?.into_series())
-        })
     }
 }
diff --git a/src/daft-core/src/series/ops/hash.rs b/src/daft-core/src/series/ops/hash.rs
index 5355353c62..1782ffc522 100644
--- a/src/daft-core/src/series/ops/hash.rs
+++ b/src/daft-core/src/series/ops/hash.rs
@@ -18,7 +18,7 @@ impl Series {
     pub fn hash_with_validity(&self, seed: Option<&UInt64Array>) -> DaftResult<UInt64Array> {
         let hash = self.hash(seed)?;
 
-        let validity = if let DataType::Null = self.data_type() {
+        let validity = if matches!(self.data_type(), DataType::Null) {
             Some(Bitmap::new_zeroed(self.len()))
         } else {
             self.validity().cloned()
diff --git a/src/daft-core/src/series/ops/is_in.rs b/src/daft-core/src/series/ops/is_in.rs
index d6655d4bb9..1e49b13825 100644
--- a/src/daft-core/src/series/ops/is_in.rs
+++ b/src/daft-core/src/series/ops/is_in.rs
@@ -28,7 +28,7 @@ impl Series {
             (self.clone(), items.clone())
         };
 
-        if let DataType::Boolean = output_type {
+        if output_type == DataType::Boolean {
             match comp_type {
                 #[cfg(feature = "python")]
                 DataType::Python => Ok(py_membership_op_utilfn(self, items)?
diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs
index d9a17dd087..81a4788067 100644
--- a/src/daft-core/src/series/ops/list.rs
+++ b/src/daft-core/src/series/ops/list.rs
@@ -7,6 +7,22 @@ use crate::{
 };
 
 impl Series {
+    pub fn list_value_counts(&self) -> DaftResult<Series> {
+        let series = match self.data_type() {
+            DataType::List(_) => self.list()?.value_counts(),
+            DataType::FixedSizeList(..) => self.fixed_size_list()?.value_counts(),
+            dt => {
+                return Err(DaftError::TypeError(format!(
+                    "value_counts not implemented for {}",
+                    dt
+                )))
+            }
+        }?
+        .into_series();
+
+        Ok(series)
+    }
+
     pub fn explode(&self) -> DaftResult<Series> {
         match self.data_type() {
             DataType::List(_) => self.list()?.explode(),
diff --git a/src/daft-core/src/series/ops/map.rs b/src/daft-core/src/series/ops/map.rs
index b624cd8aac..d5f1452bee 100644
--- a/src/daft-core/src/series/ops/map.rs
+++ b/src/daft-core/src/series/ops/map.rs
@@ -4,12 +4,13 @@ use crate::{datatypes::DataType, series::Series};
 
 impl Series {
     pub fn map_get(&self, key: &Self) -> DaftResult<Series> {
-        match self.data_type() {
-            DataType::Map(_) => self.map()?.map_get(key),
-            dt => Err(DaftError::TypeError(format!(
+        let DataType::Map { .. } = self.data_type() else {
+            return Err(DaftError::TypeError(format!(
                 "map.get not implemented for {}",
-                dt
-            ))),
-        }
+                self.data_type()
+            )));
+        };
+
+        self.map()?.map_get(key)
     }
 }
diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs
index bf7e42a1e0..76414e30e6 100644
--- a/src/daft-core/src/series/serdes.rs
+++ b/src/daft-core/src/series/serdes.rs
@@ -158,12 +158,13 @@ impl<'d> serde::Deserialize<'d> for Series {
                         DataType::Extension(..) 
=> { let physical = map.next_value::()?; let physical = physical.to_arrow(); - let ext_array = physical.to_type(field.dtype.to_arrow().unwrap()); + let ext_array = + physical.convert_logical_type(field.dtype.to_arrow().unwrap()); Ok(ExtensionArray::new(Arc::new(field), ext_array) .unwrap() .into_series()) } - DataType::Map(..) => { + DataType::Map { .. } => { let physical = map.next_value::()?; Ok(MapArray::new( Arc::new(field), diff --git a/src/daft-core/src/series/utils/mod.rs b/src/daft-core/src/series/utils/mod.rs index a262af9755..e093b50648 100644 --- a/src/daft-core/src/series/utils/mod.rs +++ b/src/daft-core/src/series/utils/mod.rs @@ -1,6 +1,6 @@ #[cfg(feature = "python")] pub(super) mod python_fn; -pub(crate) mod cast { +pub mod cast { macro_rules! cast_downcast_op { ($lhs:expr, $rhs:expr, $ty_expr:expr, $ty_type:ty, $op:ident) => {{ let lhs = $lhs.cast($ty_expr)?; diff --git a/src/daft-core/src/series/utils/python_fn.rs b/src/daft-core/src/series/utils/python_fn.rs index 2fb9112775..f0d4745999 100644 --- a/src/daft-core/src/series/utils/python_fn.rs +++ b/src/daft-core/src/series/utils/python_fn.rs @@ -2,7 +2,7 @@ use common_error::DaftResult; use crate::series::Series; -pub(crate) fn run_python_binary_operator_fn( +pub fn run_python_binary_operator_fn( lhs: &Series, rhs: &Series, operator_fn: &str, @@ -10,7 +10,7 @@ pub(crate) fn run_python_binary_operator_fn( python_binary_op_with_utilfn(lhs, rhs, operator_fn, "map_operator_arrow_semantics") } -pub(crate) fn run_python_binary_bool_operator( +pub fn run_python_binary_bool_operator( lhs: &Series, rhs: &Series, operator_fn: &str, @@ -39,7 +39,7 @@ fn python_binary_op_with_utilfn( }; let left_pylist = PySeries::from(lhs.clone()).to_pylist()?; - let right_pylist = PySeries::from(rhs.clone()).to_pylist()?; + let right_pylist = PySeries::from(rhs).to_pylist()?; let result_series: Series = Python::with_gil(|py| -> PyResult { let py_operator = @@ -60,7 +60,7 @@ fn python_binary_op_with_utilfn( Ok(result_series) } -pub(crate) fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult { +pub fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult { use pyo3::prelude::*; use crate::{datatypes::DataType, python::PySeries}; @@ -69,7 +69,7 @@ pub(crate) fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult< let rhs_casted = rhs.cast(&DataType::Python)?; let left_pylist = PySeries::from(lhs_casted.clone()).to_pylist()?; - let right_pylist = PySeries::from(rhs_casted.clone()).to_pylist()?; + let right_pylist = PySeries::from(rhs_casted).to_pylist()?; let result_series: Series = Python::with_gil(|py| -> PyResult { let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? 
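The helpers in `python_fn.rs` above all follow the same GIL round-trip: cast both operands to Python-object series, materialize them as Python lists, call a utility in `daft.utils`, and wrap the returned list back up. A stripped-down sketch of that shape, assuming `pyo3` 0.21+ (for the `*_bound` APIs used in this diff) and an interpreter that can import `daft`; the helper name mirrors the membership op above, and its exact signature should be treated as an assumption:

```rust
use pyo3::prelude::*;
use pyo3::types::PyList;

// Round-trip two columns through Python lists and call a `daft.utils` helper
// under the GIL, the same shape as `py_membership_op_utilfn`.
fn membership_via_python(left: Vec<i64>, right: Vec<i64>) -> PyResult<Vec<bool>> {
    Python::with_gil(|py| {
        let utils = PyModule::import_bound(py, "daft.utils")?;
        let result = utils.getattr("python_list_membership_check")?.call1((
            PyList::new_bound(py, left),
            PyList::new_bound(py, right),
        ))?;
        // Convert the returned Python list of bools back into Rust.
        result.extract()
    })
}
```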
@@ -92,11 +92,7 @@ pub(crate) fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult< Ok(result_series) } -pub(crate) fn py_between_op_utilfn( - value: &Series, - lower: &Series, - upper: &Series, -) -> DaftResult { +pub fn py_between_op_utilfn(value: &Series, lower: &Series, upper: &Series) -> DaftResult { use pyo3::prelude::*; use crate::{datatypes::DataType, python::PySeries}; @@ -132,8 +128,8 @@ pub(crate) fn py_between_op_utilfn( }; let value_pylist = PySeries::from(value_casted.clone()).to_pylist()?; - let lower_pylist = PySeries::from(lower_casted.clone()).to_pylist()?; - let upper_pylist = PySeries::from(upper_casted.clone()).to_pylist()?; + let lower_pylist = PySeries::from(lower_casted).to_pylist()?; + let upper_pylist = PySeries::from(upper_casted).to_pylist()?; let result_series: Series = Python::with_gil(|py| -> PyResult { let result_pylist = PyModule::import_bound(py, pyo3::intern!(py, "daft.utils"))? diff --git a/src/daft-core/src/utils/arrow.rs b/src/daft-core/src/utils/arrow.rs index 8e99be3897..229cc2dad1 100644 --- a/src/daft-core/src/utils/arrow.rs +++ b/src/daft-core/src/utils/arrow.rs @@ -117,7 +117,7 @@ pub fn cast_array_for_daft_if_needed( .unwrap(); let casted = cast_array_for_daft_if_needed(map_array.field().clone()); Box::new(arrow2::array::MapArray::new( - arrow2::datatypes::DataType::Map(to_field.clone(), sorted), + arrow2::datatypes::DataType::Map(to_field, sorted), map_array.offsets().clone(), casted, arrow_array.validity().cloned(), diff --git a/src/daft-core/src/utils/display.rs b/src/daft-core/src/utils/display.rs index 37593cda8a..e82a368b31 100644 --- a/src/daft-core/src/utils/display.rs +++ b/src/daft-core/src/utils/display.rs @@ -75,6 +75,54 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option) - ) } +const UNITS: [&str; 4] = ["d", "h", "m", "s"]; +const SIZES: [[i64; 4]; 4] = [ + [ + 86_400_000_000_000, + 3_600_000_000_000, + 60_000_000_000, + 1_000_000_000, + ], // Nanoseconds + [86_400_000_000, 3_600_000_000, 60_000_000, 1_000_000], // Microseconds + [86_400_000, 3_600_000, 60_000, 1_000], // Milliseconds + [86_400, 3_600, 60, 1], // Seconds +]; + +pub fn display_duration(val: i64, unit: &TimeUnit) -> String { + let mut output = String::new(); + let (sizes, suffix, remainder_divisor) = match unit { + TimeUnit::Nanoseconds => (&SIZES[0], "ns", 1_000_000_000), + TimeUnit::Microseconds => (&SIZES[1], "µs", 1_000_000), + TimeUnit::Milliseconds => (&SIZES[2], "ms", 1_000), + TimeUnit::Seconds => (&SIZES[3], "s", 1), + }; + + if val == 0 { + return format!("0{}", suffix); + } + + for (i, &size) in sizes.iter().enumerate() { + let whole_num = if i == 0 { + val / size + } else { + (val % sizes[i - 1]) / size + }; + if whole_num != 0 { + output.push_str(&format!("{}{}", whole_num, UNITS[i])); + if val % size != 0 { + output.push(' '); + } + } + } + + let remainder = val % remainder_divisor; + if remainder != 0 && suffix != "s" { + output.push_str(&format!("{}{}", remainder, suffix)); + } + + output +} + pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String { if scale < 0 { unimplemented!(); @@ -93,6 +141,7 @@ pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String { } } +#[must_use] pub fn display_series_literal(series: &Series) -> String { if !series.is_empty() { format!( diff --git a/src/daft-core/src/utils/dyn_compare.rs b/src/daft-core/src/utils/dyn_compare.rs index e83f5bba4d..f5c11a6eaf 100644 --- a/src/daft-core/src/utils/dyn_compare.rs +++ b/src/daft-core/src/utils/dyn_compare.rs 
@@ -18,18 +18,17 @@ pub fn build_dyn_compare(
     nulls_equal: bool,
     nans_equal: bool,
 ) -> DaftResult {
-    if left != right {
-        Err(DaftError::TypeError(format!(
-            "Types do not match when creating comparator {} vs {}",
-            left, right
-        )))
-    } else {
+    if left == right {
         Ok(build_dyn_array_compare(
             &left.to_physical().to_arrow()?,
             &right.to_physical().to_arrow()?,
             nulls_equal,
             nans_equal,
         )?)
+    } else {
+        Err(DaftError::TypeError(format!(
+            "Types do not match when creating comparator {left} vs {right}",
+        )))
     }
 }
diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs
index 2e039e6953..baf1dc66fd 100644
--- a/src/daft-core/src/utils/mod.rs
+++ b/src/daft-core/src/utils/mod.rs
@@ -2,4 +2,5 @@ pub mod arrow;
 pub mod display;
 pub mod dyn_compare;
 pub mod identity_hash_set;
+pub mod stats;
 pub mod supertype;
diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs
new file mode 100644
index 0000000000..de43b186ea
--- /dev/null
+++ b/src/daft-core/src/utils/stats.rs
@@ -0,0 +1,82 @@
+use common_error::DaftResult;
+
+use crate::{
+    array::{
+        ops::{DaftCountAggable, DaftSumAggable, GroupIndices, VecIndices},
+        prelude::{Float64Array, UInt64Array},
+    },
+    count_mode::CountMode,
+};
+
+#[derive(Clone, Copy, Default, Debug)]
+pub struct Stats {
+    pub sum: f64,
+    pub count: f64,
+    pub mean: Option<f64>,
+}
+
+pub fn calculate_stats(array: &Float64Array) -> DaftResult<Stats> {
+    let sum = array.sum()?.get(0);
+    let count = array.count(CountMode::Valid)?.get(0);
+    let stats = sum
+        .zip(count)
+        .map_or_else(Default::default, |(sum, count)| Stats {
+            sum,
+            count: count as _,
+            mean: calculate_mean(sum, count),
+        });
+    Ok(stats)
+}
+
+pub fn grouped_stats<'a>(
+    array: &Float64Array,
+    groups: &'a GroupIndices,
+) -> DaftResult<impl Iterator<Item = (Stats, &'a VecIndices)>> {
+    let grouped_sum = array.grouped_sum(groups)?;
+    let grouped_count = array.grouped_count(groups, CountMode::Valid)?;
+    debug_assert_eq!(grouped_sum.len(), grouped_count.len());
+    debug_assert_eq!(grouped_sum.len(), groups.len());
+    Ok(GroupedStats {
+        grouped_sum,
+        grouped_count,
+        groups: groups.iter().enumerate(),
+    })
+}
+
+struct GroupedStats<'a, I: Iterator<Item = (usize, &'a VecIndices)>> {
+    grouped_sum: Float64Array,
+    grouped_count: UInt64Array,
+    groups: I,
+}
+
+impl<'a, I: Iterator<Item = (usize, &'a VecIndices)>> Iterator for GroupedStats<'a, I> {
+    type Item = (Stats, &'a VecIndices);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let (index, group) = self.groups.next()?;
+        let sum = self.grouped_sum.get(index);
+        let count = self.grouped_count.get(index);
+        let stats = sum
+            .zip(count)
+            .map_or_else(Default::default, |(sum, count)| Stats {
+                sum,
+                count: count as _,
+                mean: calculate_mean(sum, count),
+            });
+        Some((stats, group))
+    }
+}
+
+pub fn calculate_mean(sum: f64, count: u64) -> Option<f64> {
+    match count {
+        0 => None,
+        _ => Some(sum / count as f64),
+    }
+}
+
+pub fn calculate_stddev(stats: Stats, values: impl Iterator<Item = f64>) -> Option<f64> {
+    stats.mean.map(|mean| {
+        let sum_of_squares = values.map(|value| (value - mean).powi(2)).sum::<f64>();
+        (sum_of_squares / stats.count).sqrt()
+    })
+}
diff --git a/src/daft-core/src/utils/supertype.rs b/src/daft-core/src/utils/supertype.rs
index 0ee0d50966..26330466f5 100644
--- a/src/daft-core/src/utils/supertype.rs
+++ b/src/daft-core/src/utils/supertype.rs
@@ -21,6 +21,7 @@ pub fn try_get_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType> {
     }
 }
 
+#[must_use]
 pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
     fn inner(l: &DataType, r: &DataType) -> Option<DataType> {
         if l == r {
diff --git a/src/daft-csv/src/metadata.rs b/src/daft-csv/src/metadata.rs
index 
c8add38d96..14d5d472ab 100644 --- a/src/daft-csv/src/metadata.rs +++ b/src/daft-csv/src/metadata.rs @@ -29,6 +29,7 @@ pub struct CsvReadStats { } impl CsvReadStats { + #[must_use] pub fn new( total_bytes_read: usize, total_records_read: usize, @@ -83,7 +84,7 @@ pub async fn read_csv_schema_bulk( let result = runtime_handle .block_on_current_thread(async { let task_stream = futures::stream::iter(uris.iter().map(|uri| { - let owned_string = uri.to_string(); + let owned_string = (*uri).to_string(); let owned_client = io_client.clone(); let owned_io_stats = io_stats.clone(); let owned_parse_options = parse_options.clone(); @@ -134,7 +135,7 @@ pub(crate) async fn read_csv_schema_single( compression_codec, parse_options, // Truncate max_bytes to size if both are set. - max_bytes.map(|m| size.map(|s| m.min(s)).unwrap_or(m)), + max_bytes.map(|m| size.map_or(m, |s| m.min(s))), ) .await } @@ -220,7 +221,7 @@ where .headers() .await? .iter() - .map(|s| s.to_string()) + .map(std::string::ToString::to_string) .collect(), false, ) @@ -324,15 +325,14 @@ mod tests { let file = format!( "{}/test/iris_tiny.csv{}", env!("CARGO_MANIFEST_DIR"), - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{ext}")) ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = - read_csv_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -364,7 +364,7 @@ mod tests { file.as_ref(), Some(CsvParseOptions::default().with_delimiter(b'|')), None, - io_client.clone(), + io_client, None, )?; assert_eq!( @@ -391,7 +391,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (_, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let (_, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!(read_stats.total_bytes_read, 328); assert_eq!(read_stats.total_records_read, 20); @@ -413,7 +413,7 @@ mod tests { file.as_ref(), Some(CsvParseOptions::default().with_has_header(false)), None, - io_client.clone(), + io_client, None, )?; assert_eq!( @@ -443,8 +443,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = - read_csv_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -469,8 +468,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = - read_csv_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -498,8 +496,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, read_stats) = - read_csv_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -526,7 +523,7 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); let (schema, read_stats) = - 
read_csv_schema(file.as_ref(), None, Some(100), io_client.clone(), None)?; + read_csv_schema(file.as_ref(), None, Some(100), io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -563,7 +560,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let err = read_csv_schema(file.as_ref(), None, None, io_client.clone(), None); + let err = read_csv_schema(file.as_ref(), None, None, io_client, None); assert!(err.is_err()); let err = err.unwrap_err(); assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); @@ -592,7 +589,7 @@ mod tests { file.as_ref(), Some(CsvParseOptions::default().with_has_header(false)), None, - io_client.clone(), + io_client, None, ); assert!(err.is_err()); @@ -634,14 +631,14 @@ mod tests { ) -> DaftResult<()> { let file = format!( "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv{}", - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{ext}")) ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let (schema, _) = read_csv_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let (schema, _) = read_csv_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ diff --git a/src/daft-csv/src/options.rs b/src/daft-csv/src/options.rs index 28ab48c573..de6e13f6c5 100644 --- a/src/daft-csv/src/options.rs +++ b/src/daft-csv/src/options.rs @@ -21,6 +21,7 @@ pub struct CsvConvertOptions { } impl CsvConvertOptions { + #[must_use] pub fn new_internal( limit: Option, include_columns: Option>, @@ -37,6 +38,7 @@ impl CsvConvertOptions { } } + #[must_use] pub fn with_limit(self, limit: Option) -> Self { Self { limit, @@ -47,6 +49,7 @@ impl CsvConvertOptions { } } + #[must_use] pub fn with_include_columns(self, include_columns: Option>) -> Self { Self { limit: self.limit, @@ -57,6 +60,7 @@ impl CsvConvertOptions { } } + #[must_use] pub fn with_column_names(self, column_names: Option>) -> Self { Self { limit: self.limit, @@ -67,6 +71,7 @@ impl CsvConvertOptions { } } + #[must_use] pub fn with_schema(self, schema: Option) -> Self { Self { limit: self.limit, @@ -98,6 +103,7 @@ impl CsvConvertOptions { /// * `predicate` - Expression to filter rows applied before the limit #[new] #[pyo3(signature = (limit=None, include_columns=None, column_names=None, schema=None, predicate=None))] + #[must_use] pub fn new( limit: Option, include_columns: Option>, @@ -109,7 +115,7 @@ impl CsvConvertOptions { limit, include_columns, column_names, - schema.map(|s| s.into()), + schema.map(std::convert::Into::into), predicate.map(|p| p.expr), ) } @@ -143,7 +149,7 @@ impl CsvConvertOptions { } pub fn __str__(&self) -> PyResult { - Ok(format!("{:?}", self)) + Ok(format!("{self:?}")) } } impl_bincode_py_state_serialization!(CsvConvertOptions); @@ -162,6 +168,7 @@ pub struct CsvParseOptions { } impl CsvParseOptions { + #[must_use] pub fn new_internal( has_header: bool, delimiter: u8, @@ -176,9 +183,9 @@ impl CsvParseOptions { delimiter, double_quote, quote, - allow_variable_columns, escape_char, comment, + allow_variable_columns, } } @@ -202,14 +209,17 @@ impl CsvParseOptions { )) } + #[must_use] pub fn with_has_header(self, has_header: bool) -> Self { Self { has_header, ..self } } + #[must_use] pub fn with_delimiter(self, delimiter: u8) -> Self { Self { delimiter, ..self } } + #[must_use] pub fn with_double_quote(self, double_quote: bool) -> Self { Self { 
double_quote, @@ -217,10 +227,12 @@ impl CsvParseOptions { } } + #[must_use] pub fn with_quote(self, quote: u8) -> Self { Self { quote, ..self } } + #[must_use] pub fn with_escape_char(self, escape_char: Option) -> Self { Self { escape_char, @@ -228,10 +240,12 @@ impl CsvParseOptions { } } + #[must_use] pub fn with_comment(self, comment: Option) -> Self { Self { comment, ..self } } + #[must_use] pub fn with_variable_columns(self, allow_variable_columns: bool) -> Self { Self { allow_variable_columns, @@ -291,7 +305,7 @@ impl CsvParseOptions { } pub fn __str__(&self) -> PyResult { - Ok(format!("{:?}", self)) + Ok(format!("{self:?}")) } } @@ -316,6 +330,7 @@ pub struct CsvReadOptions { } impl CsvReadOptions { + #[must_use] pub fn new_internal(buffer_size: Option, chunk_size: Option) -> Self { Self { buffer_size, @@ -323,6 +338,7 @@ impl CsvReadOptions { } } + #[must_use] pub fn with_buffer_size(self, buffer_size: Option) -> Self { Self { buffer_size, @@ -330,6 +346,7 @@ impl CsvReadOptions { } } + #[must_use] pub fn with_chunk_size(self, chunk_size: Option) -> Self { Self { buffer_size: self.buffer_size, @@ -355,6 +372,7 @@ impl CsvReadOptions { /// * `chunk_size` - Size of the chunks (in bytes) deserialized in parallel by the streaming reader. #[new] #[pyo3(signature = (buffer_size=None, chunk_size=None))] + #[must_use] pub fn new(buffer_size: Option, chunk_size: Option) -> Self { Self::new_internal(buffer_size, chunk_size) } @@ -368,7 +386,7 @@ impl CsvReadOptions { } pub fn __str__(&self) -> PyResult { - Ok(format!("{:?}", self)) + Ok(format!("{self:?}")) } } impl_bincode_py_state_serialization!(CsvReadOptions); diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index c0332feca8..3cdc751284 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -85,7 +85,7 @@ pub fn read_csv_bulk( // Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks. let task_stream = futures::stream::iter(uris.iter().map(|uri| { let (uri, convert_options, parse_options, read_options, io_client, io_stats) = ( - uri.to_string(), + (*uri).to_string(), convert_options.clone(), parse_options.clone(), read_options.clone(), @@ -195,7 +195,7 @@ fn tables_concat(mut tables: Vec) -> DaftResult
{ Table::new_with_size( first_table.schema.clone(), new_series, - tables.iter().map(|t| t.len()).sum(), + tables.iter().map(daft_table::Table::len).sum(), ) } @@ -226,7 +226,7 @@ async fn read_csv_single_into_table( let required_columns_for_predicate = get_required_columns(predicate); for rc in required_columns_for_predicate { if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { - include_columns.push(rc) + include_columns.push(rc); } } } @@ -352,7 +352,7 @@ async fn stream_csv_single( let required_columns_for_predicate = get_required_columns(predicate); for rc in required_columns_for_predicate { if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { - include_columns.push(rc) + include_columns.push(rc); } } } @@ -424,10 +424,10 @@ async fn read_csv_single_into_stream( io_client: Arc, io_stats: Option, ) -> DaftResult<(impl TableStream + Send, Vec)> { - let (mut schema, estimated_mean_row_size, estimated_std_row_size) = match convert_options.schema - { - Some(schema) => (schema.to_arrow()?, None, None), - None => { + let (mut schema, estimated_mean_row_size, estimated_std_row_size) = + if let Some(schema) = convert_options.schema { + (schema.to_arrow()?, None, None) + } else { let (schema, read_stats) = read_csv_schema_single( uri, parse_options.clone(), @@ -442,8 +442,7 @@ async fn read_csv_single_into_stream( Some(read_stats.mean_record_size_bytes), Some(read_stats.stddev_record_size_bytes), ) - } - }; + }; // Rename fields, if necessary. if let Some(column_names) = convert_options.column_names { schema = schema @@ -627,7 +626,7 @@ fn parse_into_column_array_chunk_stream( ) }) .collect::>>()?; - let num_rows = chunk.first().map(|s| s.len()).unwrap_or(0); + let num_rows = chunk.first().map_or(0, daft_core::series::Series::len); Ok(Table::new_unchecked(read_schema, chunk, num_rows)) })(); let _ = send.send(result); @@ -767,7 +766,7 @@ mod tests { let file = format!( "{}/test/iris_tiny.csv{}", env!("CARGO_MANIFEST_DIR"), - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{ext}")) ); let mut io_config = IOConfig::default(); @@ -828,10 +827,9 @@ mod tests { ]; let table = read_csv( file.as_ref(), - Some( - CsvConvertOptions::default() - .with_column_names(Some(column_names.iter().map(|s| s.to_string()).collect())), - ), + Some(CsvConvertOptions::default().with_column_names(Some( + column_names.iter().map(|s| (*s).to_string()).collect(), + ))), Some(CsvParseOptions::default().with_has_header(false)), None, io_client, @@ -1234,7 +1232,9 @@ mod tests { file.as_ref(), Some( CsvConvertOptions::default() - .with_column_names(Some(column_names.iter().map(|s| s.to_string()).collect())) + .with_column_names(Some( + column_names.iter().map(|s| (*s).to_string()).collect(), + )) .with_include_columns(Some(vec![ "petal.length".to_string(), "petal.width".to_string(), @@ -1860,7 +1860,7 @@ mod tests { ) -> DaftResult<()> { let file = format!( "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv{}", - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{ext}")) ); let mut io_config = IOConfig::default(); @@ -1894,10 +1894,9 @@ mod tests { let column_names = ["a", "b"]; let table = read_csv( file, - Some( - CsvConvertOptions::default() - .with_column_names(Some(column_names.iter().map(|s| s.to_string()).collect())), - ), + Some(CsvConvertOptions::default().with_column_names(Some( + column_names.iter().map(|s| (*s).to_string()).collect(), + ))), 
Some(CsvParseOptions::default().with_has_header(false)), None, io_client, @@ -1932,7 +1931,9 @@ mod tests { file, Some( CsvConvertOptions::default() - .with_column_names(Some(column_names.iter().map(|s| s.to_string()).collect())) + .with_column_names(Some( + column_names.iter().map(|s| (*s).to_string()).collect(), + )) .with_include_columns(Some(vec!["b".to_string()])), ), Some(CsvParseOptions::default().with_has_header(false)), diff --git a/src/daft-decoding/src/deserialize.rs b/src/daft-decoding/src/deserialize.rs index dd3c56292e..3552cc57a8 100644 --- a/src/daft-decoding/src/deserialize.rs +++ b/src/daft-decoding/src/deserialize.rs @@ -1,6 +1,9 @@ use arrow2::{ - array::*, - datatypes::*, + array::{ + Array, BinaryArray, BooleanArray, MutableBinaryArray, MutableUtf8Array, NullArray, + PrimitiveArray, Utf8Array, + }, + datatypes::{DataType, TimeUnit}, error::{Error, Result}, offset::Offset, temporal_conversions, @@ -72,7 +75,7 @@ where #[inline] fn significant_bytes(bytes: &[u8]) -> usize { - bytes.iter().map(|byte| (*byte != b'0') as usize).sum() + bytes.iter().map(|byte| usize::from(*byte != b'0')).sum() } /// Deserializes bytes to a single i128 representing a decimal @@ -230,14 +233,17 @@ pub fn deserialize_datetime( /// Deserializes `column` of `rows` into an [`Array`] of [`DataType`] `datatype`. #[inline] - pub fn deserialize_column( rows: &[B], column: usize, datatype: DataType, _line_number: usize, ) -> Result> { - use DataType::*; + use DataType::{ + Binary, Boolean, Date32, Date64, Decimal, Float32, Float64, Int16, Int32, Int64, Int8, + LargeBinary, LargeUtf8, Null, Time32, Time64, Timestamp, UInt16, UInt32, UInt64, UInt8, + Utf8, + }; Ok(match datatype { Boolean => deserialize_boolean(rows, column, |bytes| { if bytes.eq_ignore_ascii_case(b"false") { @@ -306,10 +312,10 @@ pub fn deserialize_column( to_utf8(bytes) .and_then(|x| x.parse::().ok()) .map(|x| { - (x.hour() as u64 * 3_600 * factor - + x.minute() as u64 * 60 * factor - + x.second() as u64 * factor - + x.nanosecond() as u64 / (1_000_000_000 / factor)) + (u64::from(x.hour()) * 3_600 * factor + + u64::from(x.minute()) * 60 * factor + + u64::from(x.second()) * factor + + u64::from(x.nanosecond()) / (1_000_000_000 / factor)) as i64 }) }), @@ -357,6 +363,7 @@ pub fn deserialize_column( } // Return the factor by how small is a time unit compared to seconds +#[must_use] pub fn get_factor_from_timeunit(time_unit: TimeUnit) -> u32 { match time_unit { TimeUnit::Second => 1, diff --git a/src/daft-decoding/src/inference.rs b/src/daft-decoding/src/inference.rs index 1d8be9c32a..a995acb0d6 100644 --- a/src/daft-decoding/src/inference.rs +++ b/src/daft-decoding/src/inference.rs @@ -15,6 +15,7 @@ use crate::deserialize::{ALL_NAIVE_DATE_FMTS, ALL_NAIVE_TIMESTAMP_FMTS, ALL_TIME /// * parsable to time-aware datetime is mapped to [`DataType::Timestamp`] of milliseconds and parsed offset. 
diff --git a/src/daft-decoding/src/inference.rs b/src/daft-decoding/src/inference.rs
index 1d8be9c32a..a995acb0d6 100644
--- a/src/daft-decoding/src/inference.rs
+++ b/src/daft-decoding/src/inference.rs
@@ -15,6 +15,7 @@ use crate::deserialize::{ALL_NAIVE_DATE_FMTS, ALL_NAIVE_TIMESTAMP_FMTS, ALL_TIME
 /// * parsable to time-aware datetime is mapped to [`DataType::Timestamp`] of milliseconds and parsed offset.
 /// * other utf8 is mapped to [`DataType::Utf8`]
 /// * invalid utf8 is mapped to [`DataType::Binary`]
+#[must_use]
 pub fn infer(bytes: &[u8]) -> arrow2::datatypes::DataType {
     if is_null(bytes) {
         DataType::Null
@@ -32,6 +33,7 @@ pub fn infer(bytes: &[u8]) -> arrow2::datatypes::DataType {
     }
 }

+#[must_use]
 pub fn infer_string(string: &str) -> DataType {
     if is_date(string) {
         DataType::Date32
diff --git a/src/daft-dsl/src/arithmetic.rs b/src/daft-dsl/src/arithmetic/mod.rs
similarity index 57%
rename from src/daft-dsl/src/arithmetic.rs
rename to src/daft-dsl/src/arithmetic/mod.rs
index 95faa64074..d4222fe64c 100644
--- a/src/daft-dsl/src/arithmetic.rs
+++ b/src/daft-dsl/src/arithmetic/mod.rs
@@ -1,3 +1,6 @@
+#[cfg(test)]
+mod tests;
+
 use crate::{Expr, ExprRef, Operator};

 macro_rules! impl_expr_op {
@@ -21,23 +24,3 @@ impl_expr_op!(sub, Minus);
 impl_expr_op!(mul, Multiply);
 impl_expr_op!(div, TrueDivide);
 impl_expr_op!(rem, Modulus);
-
-#[cfg(test)]
-mod tests {
-    use common_error::{DaftError, DaftResult};
-
-    use crate::{col, Expr};
-
-    #[test]
-    fn check_add_expr_type() -> DaftResult<()> {
-        let a = col("a");
-        let b = col("b");
-        let c = a.add(b);
-        match c.as_ref() {
-            Expr::BinaryOp { .. } => Ok(()),
-            other => Err(DaftError::ValueError(format!(
-                "expected expression to be a binary op expression, got {other:?}"
-            ))),
-        }
-    }
-}
diff --git a/src/daft-dsl/src/arithmetic/tests.rs b/src/daft-dsl/src/arithmetic/tests.rs
new file mode 100644
index 0000000000..19a7c23310
--- /dev/null
+++ b/src/daft-dsl/src/arithmetic/tests.rs
@@ -0,0 +1,16 @@
+use common_error::{DaftError, DaftResult};
+
+use crate::{col, Expr};
+
+#[test]
+fn check_add_expr_type() -> DaftResult<()> {
+    let a = col("a");
+    let b = col("b");
+    let c = a.add(b);
+    match c.as_ref() {
+        Expr::BinaryOp { .. } => Ok(()),
+        other => Err(DaftError::ValueError(format!(
+            "expected expression to be a binary op expression, got {other:?}"
+        ))),
+    }
+}
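// Editor's note: the file moves above (and the expr/join/resolve_expr moves below) all
// follow one pattern -- the inline `#[cfg(test)] mod tests { ... }` block becomes a
// sibling `tests.rs` file declared from the new `mod.rs`. A hypothetical minimal layout:
//
//   arithmetic/mod.rs:   #[cfg(test)] mod tests;   // compiled only under `cargo test`
//   arithmetic/tests.rs: use super::*;             // same access to private items
//
// Behavior is unchanged; a file-backed `mod tests` has the same visibility rules as the
// inline block it replaces.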
diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr/mod.rs
similarity index 77%
rename from src/daft-dsl/src/expr.rs
rename to src/daft-dsl/src/expr/mod.rs
index 48249355fc..873f9013bd 100644
--- a/src/daft-dsl/src/expr.rs
+++ b/src/daft-dsl/src/expr/mod.rs
@@ -1,3 +1,6 @@
+#[cfg(test)]
+mod tests;
+
 use std::{
     io::{self, Write},
     sync::Arc,
@@ -7,7 +10,7 @@ use common_error::{DaftError, DaftResult};
 use common_hashable_float_wrapper::FloatWrapper;
 use common_treenode::TreeNode;
 use daft_core::{
-    datatypes::{try_mean_supertype, try_sum_supertype, InferDataType},
+    datatypes::{try_mean_stddev_aggregation_supertype, try_sum_supertype, InferDataType},
     prelude::*,
     utils::supertype::try_get_supertype,
 };
@@ -121,6 +124,9 @@ pub enum AggExpr {
     #[display("mean({_0})")]
     Mean(ExprRef),

+    #[display("stddev({_0})")]
+    Stddev(ExprRef),
+
     #[display("min({_0})")]
     Min(ExprRef),

@@ -159,36 +165,35 @@ pub fn binary_op(op: Operator, left: ExprRef, right: ExprRef) -> ExprRef {
 impl AggExpr {
     pub fn name(&self) -> &str {
-        use AggExpr::*;
         match self {
-            Count(expr, ..)
-            | Sum(expr)
-            | ApproxPercentile(ApproxPercentileParams { child: expr, .. })
-            | ApproxCountDistinct(expr)
-            | ApproxSketch(expr, _)
-            | MergeSketch(expr, _)
-            | Mean(expr)
-            | Min(expr)
-            | Max(expr)
-            | AnyValue(expr, _)
-            | List(expr)
-            | Concat(expr) => expr.name(),
-            MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
+            Self::Count(expr, ..)
+            | Self::Sum(expr)
+            | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. })
+            | Self::ApproxCountDistinct(expr)
+            | Self::ApproxSketch(expr, _)
+            | Self::MergeSketch(expr, _)
+            | Self::Mean(expr)
+            | Self::Stddev(expr)
+            | Self::Min(expr)
+            | Self::Max(expr)
+            | Self::AnyValue(expr, _)
+            | Self::List(expr)
+            | Self::Concat(expr) => expr.name(),
+            Self::MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
         }
     }

     pub fn semantic_id(&self, schema: &Schema) -> FieldID {
-        use AggExpr::*;
         match self {
-            Count(expr, mode) => {
+            Self::Count(expr, mode) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_count({mode})"))
             }
-            Sum(expr) => {
+            Self::Sum(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_sum()"))
             }
-            ApproxPercentile(ApproxPercentileParams {
+            Self::ApproxPercentile(ApproxPercentileParams {
                 child: expr,
                 percentiles,
                 force_list_output,
@@ -199,122 +204,126 @@ impl AggExpr {
                     percentiles,
                 ))
             }
-            ApproxCountDistinct(expr) => {
+            Self::ApproxCountDistinct(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_approx_count_distinct()"))
             }
-            ApproxSketch(expr, sketch_type) => {
+            Self::ApproxSketch(expr, sketch_type) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!(
                     "{child_id}.local_approx_sketch(sketch_type={sketch_type:?})"
                 ))
             }
-            MergeSketch(expr, sketch_type) => {
+            Self::MergeSketch(expr, sketch_type) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!(
                     "{child_id}.local_merge_sketch(sketch_type={sketch_type:?})"
                 ))
             }
-            Mean(expr) => {
+            Self::Mean(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_mean()"))
             }
-            Min(expr) => {
+            Self::Stddev(expr) => {
+                let child_id = expr.semantic_id(schema);
+                FieldID::new(format!("{child_id}.local_stddev()"))
+            }
+            Self::Min(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_min()"))
             }
-            Max(expr) => {
+            Self::Max(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_max()"))
             }
-            AnyValue(expr, ignore_nulls) => {
+            Self::AnyValue(expr, ignore_nulls) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!(
                     "{child_id}.local_any_value(ignore_nulls={ignore_nulls})"
                 ))
             }
-            List(expr) => {
+            Self::List(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_list()"))
             }
-            Concat(expr) => {
+            Self::Concat(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.local_concat()"))
             }
-            MapGroups { func, inputs } => function_semantic_id(func, inputs, schema),
+            Self::MapGroups { func, inputs } => function_semantic_id(func, inputs, schema),
         }
     }

     pub fn children(&self) -> Vec<ExprRef> {
-        use AggExpr::*;
         match self {
-            Count(expr, ..)
-            | Sum(expr)
-            | ApproxPercentile(ApproxPercentileParams { child: expr, .. })
-            | ApproxCountDistinct(expr)
-            | ApproxSketch(expr, _)
-            | MergeSketch(expr, _)
-            | Mean(expr)
-            | Min(expr)
-            | Max(expr)
-            | AnyValue(expr, _)
-            | List(expr)
-            | Concat(expr) => vec![expr.clone()],
-            MapGroups { func: _, inputs } => inputs.clone(),
+            Self::Count(expr, ..)
+            | Self::Sum(expr)
+            | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. })
+            | Self::ApproxCountDistinct(expr)
+            | Self::ApproxSketch(expr, _)
+            | Self::MergeSketch(expr, _)
+            | Self::Mean(expr)
+            | Self::Stddev(expr)
+            | Self::Min(expr)
+            | Self::Max(expr)
+            | Self::AnyValue(expr, _)
+            | Self::List(expr)
+            | Self::Concat(expr) => vec![expr.clone()],
+            Self::MapGroups { func: _, inputs } => inputs.clone(),
         }
     }

-    pub fn with_new_children(&self, children: Vec<ExprRef>) -> Self {
-        use AggExpr::*;
-
-        if let MapGroups { func: _, inputs } = &self {
+    pub fn with_new_children(&self, mut children: Vec<ExprRef>) -> Self {
+        if let Self::MapGroups { func: _, inputs } = &self {
             assert_eq!(children.len(), inputs.len());
         } else {
             assert_eq!(children.len(), 1);
         }
+        let mut first_child = || children.pop().unwrap();
         match self {
-            Count(_, count_mode) => Count(children[0].clone(), *count_mode),
-            Sum(_) => Sum(children[0].clone()),
-            Mean(_) => Mean(children[0].clone()),
-            Min(_) => Min(children[0].clone()),
-            Max(_) => Max(children[0].clone()),
-            AnyValue(_, ignore_nulls) => AnyValue(children[0].clone(), *ignore_nulls),
-            List(_) => List(children[0].clone()),
-            Concat(_) => Concat(children[0].clone()),
-            MapGroups { func, inputs: _ } => MapGroups {
+            Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode),
+            Self::Sum(_) => Self::Sum(first_child()),
+            Self::Mean(_) => Self::Mean(first_child()),
+            Self::Stddev(_) => Self::Stddev(first_child()),
+            Self::Min(_) => Self::Min(first_child()),
+            Self::Max(_) => Self::Max(first_child()),
+            Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls),
+            Self::List(_) => Self::List(first_child()),
+            Self::Concat(_) => Self::Concat(first_child()),
+            Self::MapGroups { func, inputs: _ } => Self::MapGroups {
                 func: func.clone(),
                 inputs: children,
             },
-            ApproxPercentile(ApproxPercentileParams {
+            Self::ApproxPercentile(ApproxPercentileParams {
                 percentiles,
                 force_list_output,
                 ..
-            }) => ApproxPercentile(ApproxPercentileParams {
-                child: children[0].clone(),
+            }) => Self::ApproxPercentile(ApproxPercentileParams {
+                child: first_child(),
                 percentiles: percentiles.clone(),
                 force_list_output: *force_list_output,
             }),
-            ApproxCountDistinct(_) => ApproxCountDistinct(children[0].clone()),
-            &ApproxSketch(_, sketch_type) => ApproxSketch(children[0].clone(), sketch_type),
-            &MergeSketch(_, sketch_type) => MergeSketch(children[0].clone(), sketch_type),
+            Self::ApproxCountDistinct(_) => Self::ApproxCountDistinct(first_child()),
+            &Self::ApproxSketch(_, sketch_type) => Self::ApproxSketch(first_child(), sketch_type),
+            &Self::MergeSketch(_, sketch_type) => Self::MergeSketch(first_child(), sketch_type),
         }
     }
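// Editor's note: a hypothetical standalone reduction (not from the patch) of the
// `with_new_children` rewrite above. Taking `children: Vec<ExprRef>` by value and popping
// hands each single child out by move, replacing the old `children[0].clone()` and its
// extra Arc refcount bump.
fn take_single<T>(mut children: Vec<T>) -> T {
    assert_eq!(children.len(), 1, "Should have 1 child");
    // `pop` moves the element out; no clone, no indexing panic path.
    children.pop().unwrap()
}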
     pub fn to_field(&self, schema: &Schema) -> DaftResult<Field> {
-        use AggExpr::*;
         match self {
-            Count(expr, ..) => {
+            Self::Count(expr, ..) => {
                 let field = expr.to_field(schema)?;
                 Ok(Field::new(field.name.as_str(), DataType::UInt64))
             }
-            Sum(expr) => {
+            Self::Sum(expr) => {
                 let field = expr.to_field(schema)?;
                 Ok(Field::new(
                     field.name.as_str(),
                     try_sum_supertype(&field.dtype)?,
                 ))
             }
-            ApproxPercentile(ApproxPercentileParams {
+
+            Self::ApproxPercentile(ApproxPercentileParams {
                 child: expr,
                 percentiles,
                 force_list_output,
@@ -337,11 +346,11 @@ impl AggExpr {
                     },
                 ))
             }
-            ApproxCountDistinct(expr) => {
+            Self::ApproxCountDistinct(expr) => {
                 let field = expr.to_field(schema)?;
                 Ok(Field::new(field.name.as_str(), DataType::UInt64))
             }
-            ApproxSketch(expr, sketch_type) => {
+            Self::ApproxSketch(expr, sketch_type) => {
                 let field = expr.to_field(schema)?;
                 let dtype = match sketch_type {
                     SketchType::DDSketch => {
@@ -357,7 +366,7 @@ impl AggExpr {
                 };
                 Ok(Field::new(field.name, dtype))
             }
-            MergeSketch(expr, sketch_type) => {
+            Self::MergeSketch(expr, sketch_type) => {
                 let field = expr.to_field(schema)?;
                 let dtype = match sketch_type {
                     SketchType::DDSketch => {
@@ -374,19 +383,19 @@ impl AggExpr {
                 };
                 Ok(Field::new(field.name, dtype))
             }
-            Mean(expr) => {
+            Self::Mean(expr) | Self::Stddev(expr) => {
                 let field = expr.to_field(schema)?;
                 Ok(Field::new(
                     field.name.as_str(),
-                    try_mean_supertype(&field.dtype)?,
+                    try_mean_stddev_aggregation_supertype(&field.dtype)?,
                 ))
             }
-            Min(expr) | Max(expr) | AnyValue(expr, _) => {
+            Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) => {
                 let field = expr.to_field(schema)?;
                 Ok(Field::new(field.name.as_str(), field.dtype))
             }
-            List(expr) => expr.to_field(schema)?.to_list_field(),
-            Concat(expr) => {
+            Self::List(expr) => expr.to_field(schema)?.to_list_field(),
+            Self::Concat(expr) => {
                 let field = expr.to_field(schema)?;
                 match field.dtype {
                     DataType::List(..) => Ok(field),
@@ -399,23 +408,7 @@ impl AggExpr {
                     ))),
                 }
             }
-            MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func),
-        }
-    }
-
-    pub fn from_name_and_child_expr(name: &str, child: ExprRef) -> DaftResult<Self> {
-        use AggExpr::*;
-        match name {
-            "count" => Ok(Count(child, CountMode::Valid)),
-            "sum" => Ok(Sum(child)),
-            "mean" => Ok(Mean(child)),
-            "min" => Ok(Min(child)),
-            "max" => Ok(Max(child)),
-            "list" => Ok(List(child)),
-            _ => Err(DaftError::ValueError(format!(
-                "{} not a valid aggregation name",
-                name
-            ))),
+            Self::MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func),
         }
     }
 }
@@ -498,6 +491,10 @@ impl Expr {
         Self::Agg(AggExpr::Mean(self)).into()
     }

+    pub fn stddev(self: ExprRef) -> ExprRef {
+        Self::Agg(AggExpr::Stddev(self)).into()
+    }
+
     pub fn min(self: ExprRef) -> ExprRef {
         Self::Agg(AggExpr::Min(self)).into()
     }
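// Editor's note: a hedged usage sketch for the new `stddev` aggregation added above;
// assumes the public `daft_dsl::col` helper. Output typing now goes through the shared
// `try_mean_stddev_aggregation_supertype`, so `mean` and `stddev` follow one rule
// (e.g. integer columns aggregate to a float type).
fn stddev_of(name: &str) -> daft_dsl::ExprRef {
    daft_dsl::col(name).stddev()
}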
@@ -576,57 +573,55 @@ impl Expr {
     }

     pub fn semantic_id(&self, schema: &Schema) -> FieldID {
-        use Expr::*;
         match self {
             // Base case - anonymous column reference.
             // Look up the column name in the provided schema and get its field ID.
-            Column(name) => FieldID::new(&**name),
+            Self::Column(name) => FieldID::new(&**name),

             // Base case - literal.
-            Literal(value) => FieldID::new(format!("Literal({value:?})")),
+            Self::Literal(value) => FieldID::new(format!("Literal({value:?})")),

             // Recursive cases.
-            Cast(expr, dtype) => {
+            Self::Cast(expr, dtype) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.cast({dtype})"))
             }
-            Not(expr) => {
+            Self::Not(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.not()"))
             }
-            IsNull(expr) => {
+            Self::IsNull(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.is_null()"))
             }
-            NotNull(expr) => {
+            Self::NotNull(expr) => {
                 let child_id = expr.semantic_id(schema);
                 FieldID::new(format!("{child_id}.not_null()"))
             }
-            FillNull(expr, fill_value) => {
+            Self::FillNull(expr, fill_value) => {
                 let child_id = expr.semantic_id(schema);
                 let fill_value_id = fill_value.semantic_id(schema);
                 FieldID::new(format!("{child_id}.fill_null({fill_value_id})"))
             }
-            IsIn(expr, items) => {
+            Self::IsIn(expr, items) => {
                 let child_id = expr.semantic_id(schema);
                 let items_id = items.semantic_id(schema);
                 FieldID::new(format!("{child_id}.is_in({items_id})"))
             }
-            Between(expr, lower, upper) => {
+            Self::Between(expr, lower, upper) => {
                 let child_id = expr.semantic_id(schema);
                 let lower_id = lower.semantic_id(schema);
                 let upper_id = upper.semantic_id(schema);
                 FieldID::new(format!("{child_id}.between({lower_id},{upper_id})"))
             }
-            Function { func, inputs } => function_semantic_id(func, inputs, schema),
-            BinaryOp { op, left, right } => {
+            Self::Function { func, inputs } => function_semantic_id(func, inputs, schema),
+            Self::BinaryOp { op, left, right } => {
                 let left_id = left.semantic_id(schema);
                 let right_id = right.semantic_id(schema);
                 // TODO: check for symmetry here.
                 FieldID::new(format!("({left_id} {op} {right_id})"))
             }
-
-            IfElse {
+            Self::IfElse {
                 if_true,
                 if_false,
                 predicate,
@@ -636,96 +631,100 @@ impl Expr {
                 let predicate = predicate.semantic_id(schema);
                 FieldID::new(format!("({if_true} if {predicate} else {if_false})"))
             }
             // Alias: ID does not change.
-            Alias(expr, ..) => expr.semantic_id(schema),
-
+            Self::Alias(expr, ..) => expr.semantic_id(schema),
             // Agg: Separate path.
-            Agg(agg_expr) => agg_expr.semantic_id(schema),
-            ScalarFunction(sf) => scalar_function_semantic_id(sf, schema),
+            Self::Agg(agg_expr) => agg_expr.semantic_id(schema),
+            Self::ScalarFunction(sf) => scalar_function_semantic_id(sf, schema),
         }
     }

     pub fn children(&self) -> Vec<ExprRef> {
-        use Expr::*;
         match self {
             // No children.
-            Column(..) => vec![],
-            Literal(..) => vec![],
+            Self::Column(..) => vec![],
+            Self::Literal(..) => vec![],

             // One child.
-            Not(expr) | IsNull(expr) | NotNull(expr) | Cast(expr, ..) | Alias(expr, ..) => {
+            Self::Not(expr)
+            | Self::IsNull(expr)
+            | Self::NotNull(expr)
+            | Self::Cast(expr, ..)
+            | Self::Alias(expr, ..) => {
                 vec![expr.clone()]
             }
-            Agg(agg_expr) => agg_expr.children(),
+            Self::Agg(agg_expr) => agg_expr.children(),

             // Multiple children.
-            Function { inputs, .. } => inputs.clone(),
-            BinaryOp { left, right, .. } => {
+            Self::Function { inputs, .. } => inputs.clone(),
+            Self::BinaryOp { left, right, .. } => {
                 vec![left.clone(), right.clone()]
             }
-            IsIn(expr, items) => vec![expr.clone(), items.clone()],
-            Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()],
-            IfElse {
+            Self::IsIn(expr, items) => vec![expr.clone(), items.clone()],
+            Self::Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()],
+            Self::IfElse {
                 if_true,
                 if_false,
                 predicate,
             } => {
                 vec![if_true.clone(), if_false.clone(), predicate.clone()]
             }
-            FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()],
-            ScalarFunction(sf) => sf.inputs.clone(),
+            Self::FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()],
+            Self::ScalarFunction(sf) => sf.inputs.clone(),
         }
     }

     pub fn with_new_children(&self, children: Vec<ExprRef>) -> Self {
-        use Expr::*;
         match self {
             // no children
-            Column(..) | Literal(..) => {
+            Self::Column(..) | Self::Literal(..) => {
                 assert!(children.is_empty(), "Should have no children");
                 self.clone()
             }
             // 1 child
-            Not(..) => Not(children.first().expect("Should have 1 child").clone()),
-            Alias(.., name) => Alias(
+            Self::Not(..) => Self::Not(children.first().expect("Should have 1 child").clone()),
+            Self::Alias(.., name) => Self::Alias(
                 children.first().expect("Should have 1 child").clone(),
                 name.clone(),
             ),
-            IsNull(..) => IsNull(children.first().expect("Should have 1 child").clone()),
-            NotNull(..) => NotNull(children.first().expect("Should have 1 child").clone()),
-            Cast(.., dtype) => Cast(
+            Self::IsNull(..) => {
+                Self::IsNull(children.first().expect("Should have 1 child").clone())
+            }
+            Self::NotNull(..) => {
+                Self::NotNull(children.first().expect("Should have 1 child").clone())
+            }
+            Self::Cast(.., dtype) => Self::Cast(
                 children.first().expect("Should have 1 child").clone(),
                 dtype.clone(),
             ),
             // 2 children
-            BinaryOp { op, .. } => BinaryOp {
+            Self::BinaryOp { op, .. } => Self::BinaryOp {
                 op: *op,
                 left: children.first().expect("Should have 1 child").clone(),
                 right: children.get(1).expect("Should have 2 child").clone(),
             },
-            IsIn(..) => IsIn(
+            Self::IsIn(..) => Self::IsIn(
                 children.first().expect("Should have 1 child").clone(),
                 children.get(1).expect("Should have 2 child").clone(),
             ),
-            Between(..) => Between(
+            Self::Between(..) => Self::Between(
                 children.first().expect("Should have 1 child").clone(),
                 children.get(1).expect("Should have 2 child").clone(),
                 children.get(2).expect("Should have 3 child").clone(),
             ),
-            FillNull(..) => FillNull(
+            Self::FillNull(..) => Self::FillNull(
                 children.first().expect("Should have 1 child").clone(),
                 children.get(1).expect("Should have 2 child").clone(),
             ),
             // ternary
-            IfElse { .. } => IfElse {
+            Self::IfElse { .. } => Self::IfElse {
                 if_true: children.first().expect("Should have 1 child").clone(),
                 if_false: children.get(1).expect("Should have 2 child").clone(),
                 predicate: children.get(2).expect("Should have 3 child").clone(),
             },
             // N-ary
-            Agg(agg_expr) => Agg(agg_expr.with_new_children(children)),
-            Function {
+            Self::Agg(agg_expr) => Self::Agg(agg_expr.with_new_children(children)),
+            Self::Function {
                 func,
                 inputs: old_children,
             } => {
@@ -733,18 +732,18 @@ impl Expr {
                     children.len() == old_children.len(),
                     "Should have same number of children"
                 );
-                Function {
+                Self::Function {
                     func: func.clone(),
                     inputs: children,
                 }
             }
-            ScalarFunction(sf) => {
+            Self::ScalarFunction(sf) => {
                 assert!(
                     children.len() == sf.inputs.len(),
                     "Should have same number of children"
                 );
-                ScalarFunction(crate::functions::ScalarFunction {
+                Self::ScalarFunction(crate::functions::ScalarFunction {
                     udf: sf.udf.clone(),
                     inputs: children,
                 })
@@ -753,13 +752,12 @@ impl Expr {
     }

     pub fn to_field(&self, schema: &Schema) -> DaftResult<Field> {
-        use Expr::*;
         match self {
-            Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)),
-            Agg(agg_expr) => agg_expr.to_field(schema),
-            Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())),
-            Column(name) => Ok(schema.get_field(name).cloned()?),
-            Not(expr) => {
+            Self::Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)),
+            Self::Agg(agg_expr) => agg_expr.to_field(schema),
+            Self::Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())),
+            Self::Column(name) => Ok(schema.get_field(name).cloned()?),
+            Self::Not(expr) => {
                 let child_field = expr.to_field(schema)?;
                 match child_field.dtype {
                     DataType::Boolean => Ok(Field::new(expr.name(), DataType::Boolean)),
@@ -768,9 +766,9 @@ impl Expr {
                     ))),
                 }
             }
-            IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)),
-            NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)),
-            FillNull(expr, fill_value) => {
+            Self::IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)),
+            Self::NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)),
+            Self::FillNull(expr, fill_value) => {
                 let expr_field = expr.to_field(schema)?;
                 let fill_value_field = fill_value.to_field(schema)?;
                 match try_get_supertype(&expr_field.dtype, &fill_value_field.dtype) {
@@ -780,7 +778,7 @@ impl Expr {
                     )))
                 }
             }
-            IsIn(left, right) => {
+            Self::IsIn(left, right) => {
                 let left_field = left.to_field(schema)?;
                 let right_field = right.to_field(schema)?;
                 let (result_type, _intermediate, _comp_type) =
@@ -788,7 +786,7 @@ impl Expr {
                         .membership_op(&InferDataType::from(&right_field.dtype))?;
                 Ok(Field::new(left_field.name.as_str(), result_type))
             }
-            Between(value, lower, upper) => {
+            Self::Between(value, lower, upper) => {
                 let value_field = value.to_field(schema)?;
                 let lower_field = lower.to_field(schema)?;
                 let upper_field = upper.to_field(schema)?;
@@ -803,11 +801,10 @@ impl Expr {
                         .membership_op(&InferDataType::from(&upper_result_type))?;
                 Ok(Field::new(value_field.name.as_str(), result_type))
             }
-            Literal(value) => Ok(Field::new("literal", value.get_type())),
-            Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func),
-            ScalarFunction(sf) => sf.to_field(schema),
-
-            BinaryOp { op, left, right } => {
+            Self::Literal(value) => Ok(Field::new("literal", value.get_type())),
+            Self::Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func),
+            Self::ScalarFunction(sf) => sf.to_field(schema),
+            Self::BinaryOp { op, left, right } => {
                 let left_field = left.to_field(schema)?;
                 let right_field = right.to_field(schema)?;
@@ -873,7 +870,7 @@ impl Expr {
                     }
                 }
             }
-            IfElse {
+            Self::IfElse {
                 if_true,
                 if_false,
                 predicate,
@@ -903,33 +900,32 @@ impl Expr {
     }

     pub fn name(&self) -> &str {
-        use Expr::*;
         match self {
-            Alias(.., name) => name.as_ref(),
-            Agg(agg_expr) => agg_expr.name(),
-            Cast(expr, ..) => expr.name(),
-            Column(name) => name.as_ref(),
-            Not(expr) => expr.name(),
-            IsNull(expr) => expr.name(),
-            NotNull(expr) => expr.name(),
-            FillNull(expr, ..) => expr.name(),
-            IsIn(expr, ..) => expr.name(),
-            Between(expr, ..) => expr.name(),
-            Literal(..) => "literal",
-            Function { func, inputs } => match func {
+            Self::Alias(.., name) => name.as_ref(),
+            Self::Agg(agg_expr) => agg_expr.name(),
+            Self::Cast(expr, ..) => expr.name(),
+            Self::Column(name) => name.as_ref(),
+            Self::Not(expr) => expr.name(),
+            Self::IsNull(expr) => expr.name(),
+            Self::NotNull(expr) => expr.name(),
+            Self::FillNull(expr, ..) => expr.name(),
+            Self::IsIn(expr, ..) => expr.name(),
+            Self::Between(expr, ..) => expr.name(),
+            Self::Literal(..) => "literal",
+            Self::Function { func, inputs } => match func {
                 FunctionExpr::Struct(StructExpr::Get(name)) => name,
                 _ => inputs.first().unwrap().name(),
             },
-            ScalarFunction(func) => match func.name() {
+            Self::ScalarFunction(func) => match func.name() {
                 "to_struct" => "struct", // FIXME: make .name() use output name from schema
                 _ => func.inputs.first().unwrap().name(),
             },
-            BinaryOp {
+            Self::BinaryOp {
                 op: _,
                 left,
                 right: _,
             } => left.name(),
-            IfElse { if_true, .. } => if_true.name(),
+            Self::IfElse { if_true, .. } => if_true.name(),
         }
     }

@@ -1024,7 +1020,7 @@ impl Expr {
         let mut buffer = Vec::new();
         to_sql_inner(self, &mut buffer)
             .ok()
-            .and_then(|_| String::from_utf8(buffer).ok())
+            .and_then(|()| String::from_utf8(buffer).ok())
     }

     /// If the expression is a literal, return it. Otherwise, return None.
@@ -1119,90 +1115,3 @@ pub fn has_stateful_udf(expr: &ExprRef) -> bool {
         )
     })
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn check_comparison_type() -> DaftResult<()> {
-        let x = lit(10.);
-        let y = lit(12);
-        let schema = Schema::empty();
-
-        let z = Expr::BinaryOp {
-            left: x,
-            right: y,
-            op: Operator::Lt,
-        };
-        assert_eq!(z.get_type(&schema)?, DataType::Boolean);
-        Ok(())
-    }
-
-    #[test]
-    fn check_alias_type() -> DaftResult<()> {
-        let a = col("a");
-        let b = a.alias("b");
-        match b.as_ref() {
-            Expr::Alias(..) => Ok(()),
-            other => Err(common_error::DaftError::ValueError(format!(
-                "expected expression to be a alias, got {other:?}"
-            ))),
-        }
-    }
-
-    #[test]
-    fn check_arithmetic_type() -> DaftResult<()> {
-        let x = lit(10.);
-        let y = lit(12);
-        let schema = Schema::empty();
-
-        let z = Expr::BinaryOp {
-            left: x,
-            right: y,
-            op: Operator::Plus,
-        };
-        assert_eq!(z.get_type(&schema)?, DataType::Float64);
-
-        let x = lit(10.);
-        let y = lit(12);
-
-        let z = Expr::BinaryOp {
-            left: y,
-            right: x,
-            op: Operator::Plus,
-        };
-        assert_eq!(z.get_type(&schema)?, DataType::Float64);
-
-        Ok(())
-    }
-
-    #[test]
-    fn check_arithmetic_type_with_columns() -> DaftResult<()> {
-        let x = col("x");
-        let y = col("y");
-        let schema = Schema::new(vec![
-            Field::new("x", DataType::Float64),
-            Field::new("y", DataType::Int64),
-        ])?;
-
-        let z = Expr::BinaryOp {
-            left: x,
-            right: y,
-            op: Operator::Plus,
-        };
-        assert_eq!(z.get_type(&schema)?, DataType::Float64);
-
-        let x = col("x");
-        let y = col("y");
-
-        let z = Expr::BinaryOp {
-            left: y,
-            right: x,
-            op: Operator::Plus,
-        };
-        assert_eq!(z.get_type(&schema)?, DataType::Float64);
-
-        Ok(())
-    }
-}
diff --git a/src/daft-dsl/src/expr/tests.rs b/src/daft-dsl/src/expr/tests.rs
new file mode 100644
index 0000000000..aff680c5d3
--- /dev/null
+++ b/src/daft-dsl/src/expr/tests.rs
@@ -0,0 +1,83 @@
+use super::*;
+
+#[test]
+fn check_comparison_type() -> DaftResult<()> {
+    let x = lit(10.);
+    let y = lit(12);
+    let schema = Schema::empty();
+
+    let z = Expr::BinaryOp {
+        left: x,
+        right: y,
+        op: Operator::Lt,
+    };
+    assert_eq!(z.get_type(&schema)?, DataType::Boolean);
+    Ok(())
+}
+
+#[test]
+fn check_alias_type() -> DaftResult<()> {
+    let a = col("a");
+    let b = a.alias("b");
+    match b.as_ref() {
+        Expr::Alias(..) => Ok(()),
+        other => Err(common_error::DaftError::ValueError(format!(
+            "expected expression to be a alias, got {other:?}"
+        ))),
+    }
+}
+
+#[test]
+fn check_arithmetic_type() -> DaftResult<()> {
+    let x = lit(10.);
+    let y = lit(12);
+    let schema = Schema::empty();
+
+    let z = Expr::BinaryOp {
+        left: x,
+        right: y,
+        op: Operator::Plus,
+    };
+    assert_eq!(z.get_type(&schema)?, DataType::Float64);
+
+    let x = lit(10.);
+    let y = lit(12);
+
+    let z = Expr::BinaryOp {
+        left: y,
+        right: x,
+        op: Operator::Plus,
+    };
+    assert_eq!(z.get_type(&schema)?, DataType::Float64);
+
+    Ok(())
+}
+
+#[test]
+fn check_arithmetic_type_with_columns() -> DaftResult<()> {
+    let x = col("x");
+    let y = col("y");
+    let schema = Schema::new(vec![
+        Field::new("x", DataType::Float64),
+        Field::new("y", DataType::Int64),
+    ])?;
+
+    let z = Expr::BinaryOp {
+        left: x,
+        right: y,
+        op: Operator::Plus,
+    };
+    assert_eq!(z.get_type(&schema)?, DataType::Float64);
+
+    let x = col("x");
+    let y = col("y");
+
+    let z = Expr::BinaryOp {
+        left: y,
+        right: x,
+        op: Operator::Plus,
+    };
+    assert_eq!(z.get_type(&schema)?, DataType::Float64);
+
+    Ok(())
+}
diff --git a/src/daft-dsl/src/functions/map/get.rs b/src/daft-dsl/src/functions/map/get.rs
index ab6eb148f8..5465f08562 100644
--- a/src/daft-dsl/src/functions/map/get.rs
+++ b/src/daft-dsl/src/functions/map/get.rs
@@ -12,40 +12,36 @@ impl FunctionEvaluator for GetEvaluator {
     }

     fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
-        match inputs {
-            [input, key] => match (input.to_field(schema), key.to_field(schema)) {
-                (Ok(input_field), Ok(_)) => match input_field.dtype {
-                    DataType::Map(inner) => match inner.as_ref() {
-                        DataType::Struct(fields) if fields.len() == 2 => {
-                            let value_dtype = &fields[1].dtype;
-                            Ok(Field::new("value", value_dtype.clone()))
-                        }
-                        _ => Err(DaftError::TypeError(format!(
-                            "Expected input map to have struct values with 2 fields, got {}",
-                            inner
-                        ))),
-                    },
-                    _ => Err(DaftError::TypeError(format!(
-                        "Expected input to be a map, got {}",
-                        input_field.dtype
-                    ))),
-                },
-                (Err(e), _) | (_, Err(e)) => Err(e),
-            },
-            _ => Err(DaftError::SchemaMismatch(format!(
+        let [input, key] = inputs else {
+            return Err(DaftError::SchemaMismatch(format!(
                 "Expected 2 input args, got {}",
                 inputs.len()
-            ))),
-        }
+            )));
+        };
+
+        let input_field = input.to_field(schema)?;
+        let _ = key.to_field(schema)?;
+
+        let DataType::Map { value, .. } = input_field.dtype else {
+            return Err(DaftError::TypeError(format!(
+                "Expected input to be a map, got {}",
+                input_field.dtype
+            )));
+        };
+
+        let field = Field::new("value", *value);
+
+        Ok(field)
     }

     fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
-        match inputs {
-            [input, key] => input.map_get(key),
-            _ => Err(DaftError::ValueError(format!(
+        let [input, key] = inputs else {
+            return Err(DaftError::ValueError(format!(
                 "Expected 2 input args, got {}",
                 inputs.len()
-            ))),
-        }
+            )));
+        };
+
+        input.map_get(key)
     }
 }
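// Editor's note: an illustrative, self-contained reduction (not from the patch) of the
// `GetEvaluator` rewrite above: `let ... else` with an early return flattens the old
// nested-match pyramid and keeps the happy path at the top indentation level.
fn two_args(inputs: &[i64]) -> Result<(i64, i64), String> {
    let [input, key] = inputs else {
        return Err(format!("Expected 2 input args, got {}", inputs.len()));
    };
    Ok((*input, *key))
}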
diff --git a/src/daft-dsl/src/functions/map/mod.rs b/src/daft-dsl/src/functions/map/mod.rs
index 979a6ccd1e..083e99e7db 100644
--- a/src/daft-dsl/src/functions/map/mod.rs
+++ b/src/daft-dsl/src/functions/map/mod.rs
@@ -14,9 +14,8 @@ pub enum MapExpr {
 impl MapExpr {
     #[inline]
     pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        use MapExpr::*;
         match self {
-            Get => &GetEvaluator {},
+            Self::Get => &GetEvaluator {},
         }
     }
 }
diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs
index 0386d7c54c..6f0b162422 100644
--- a/src/daft-dsl/src/functions/mod.rs
+++ b/src/daft-dsl/src/functions/mod.rs
@@ -1,5 +1,6 @@
 pub mod map;
 pub mod partitioning;
+pub mod python;
 pub mod scalar;
 pub mod sketch;
 pub mod struct_;
@@ -12,6 +13,7 @@ use std::{

 use common_error::DaftResult;
 use daft_core::prelude::*;
+use python::PythonUDF;
 pub use scalar::*;
 use serde::{Deserialize, Serialize};

@@ -21,9 +23,6 @@ use self::{
 };
 use crate::{Expr, ExprRef, Operator};

-pub mod python;
-use python::PythonUDF;
-
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
 pub enum FunctionExpr {
     Utf8(Utf8Expr),
@@ -48,14 +47,13 @@ pub trait FunctionEvaluator {
 impl FunctionExpr {
     #[inline]
     fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        use FunctionExpr::*;
         match self {
-            Utf8(expr) => expr.get_evaluator(),
-            Map(expr) => expr.get_evaluator(),
-            Sketch(expr) => expr.get_evaluator(),
-            Struct(expr) => expr.get_evaluator(),
-            Python(expr) => expr.get_evaluator(),
-            Partitioning(expr) => expr.get_evaluator(),
+            Self::Utf8(expr) => expr.get_evaluator(),
+            Self::Map(expr) => expr.get_evaluator(),
+            Self::Sketch(expr) => expr.get_evaluator(),
+            Self::Struct(expr) => expr.get_evaluator(),
+            Self::Python(expr) => expr.get_evaluator(),
+            Self::Partitioning(expr) => expr.get_evaluator(),
         }
     }
 }
diff --git a/src/daft-dsl/src/functions/partitioning/mod.rs b/src/daft-dsl/src/functions/partitioning/mod.rs
index 9f37414e18..ead6ed91f8 100644
--- a/src/daft-dsl/src/functions/partitioning/mod.rs
+++ b/src/daft-dsl/src/functions/partitioning/mod.rs
@@ -24,14 +24,13 @@ pub enum PartitioningExpr {
 impl PartitioningExpr {
     #[inline]
     pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        use PartitioningExpr::*;
         match self {
-            Years => &YearsEvaluator {},
-            Months => &MonthsEvaluator {},
-            Days => &DaysEvaluator {},
-            Hours => &HoursEvaluator {},
-            IcebergBucket(..) => &IcebergBucketEvaluator {},
-            IcebergTruncate(..) => &IcebergTruncateEvaluator {},
+            Self::Years => &YearsEvaluator {},
+            Self::Months => &MonthsEvaluator {},
+            Self::Days => &DaysEvaluator {},
+            Self::Hours => &HoursEvaluator {},
+            Self::IcebergBucket(..) => &IcebergBucketEvaluator {},
+            Self::IcebergTruncate(..) => &IcebergTruncateEvaluator {},
         }
     }
 }
diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs
index 378611851a..c4f69f8331 100644
--- a/src/daft-dsl/src/functions/python/mod.rs
+++ b/src/daft-dsl/src/functions/python/mod.rs
@@ -2,13 +2,19 @@ mod runtime_py_object;
 mod udf;
 mod udf_runtime_binding;

-use std::{collections::HashMap, sync::Arc};
+#[cfg(feature = "python")]
+use std::collections::HashMap;
+use std::sync::Arc;

-use common_error::{DaftError, DaftResult};
+#[cfg(feature = "python")]
+use common_error::DaftError;
+use common_error::DaftResult;
 use common_resource_request::ResourceRequest;
 use common_treenode::{TreeNode, TreeNodeRecursion};
 use daft_core::datatypes::DataType;
 use itertools::Itertools;
+#[cfg(feature = "python")]
+use pyo3::{Py, PyAny};
 pub use runtime_py_object::RuntimePyObject;
 use serde::{Deserialize, Serialize};
 pub use udf_runtime_binding::UDFRuntimeBinding;
@@ -128,7 +134,7 @@ pub fn get_resource_request(exprs: &[ExprRef]) -> Option<ResourceRequest> {
                 ..
             } => {
                 if let Some(rr) = resource_request {
-                    resource_requests.push(rr.clone())
+                    resource_requests.push(rr.clone());
                 }
                 Ok(TreeNodeRecursion::Continue)
             }
@@ -156,7 +162,7 @@ pub fn get_resource_request(exprs: &[ExprRef]) -> Option<ResourceRequest> {
 /// NOTE: This function panics if no StatefulUDF is found
 pub fn get_concurrency(exprs: &[ExprRef]) -> usize {
     let mut projection_concurrency = None;
-    for expr in exprs.iter() {
+    for expr in exprs {
         let mut found_stateful_udf = false;
         expr.apply(|e| match e.as_ref() {
             Expr::Function {
@@ -180,7 +186,7 @@ pub fn get_concurrency(exprs: &[ExprRef]) -> usize {
 #[cfg(feature = "python")]
 pub fn bind_stateful_udfs(
     expr: ExprRef,
-    initialized_funcs: &HashMap<String, pyo3::Py<pyo3::PyAny>>,
+    initialized_funcs: &HashMap<String, Py<PyAny>>,
 ) -> DaftResult<ExprRef> {
     expr.transform(|e| match e.as_ref() {
         Expr::Function {
@@ -213,7 +219,9 @@ pub fn bind_stateful_udfs(

 /// Helper function that extracts all PartialStatefulUDF python objects from a given expression tree
 #[cfg(feature = "python")]
-pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap<String, Py<PyAny>> {
+pub fn extract_partial_stateful_udf_py(
+    expr: ExprRef,
+) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
     let mut py_partial_udfs = HashMap::new();
     expr.apply(|child| {
         if let Expr::Function {
@@ -221,12 +229,19 @@ pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap<String, Py<PyAny>> {
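// Editor's note: the body of the hunk above was lost in extraction; only its header
// survives. A simplified, hypothetical model (not the real `common_treenode` API) of what
// `get_resource_request` above does: a visitor walks every expression node, pushes each
// UDF's `resource_request`, and the collected requests are folded into a single
// element-wise maximum. The fold is sketched here with invented `Req` fields.
#[derive(Clone, Copy, Default)]
struct Req {
    num_cpus: f64,
    num_gpus: f64,
}

fn fold_requests(collected: &[Req]) -> Option<Req> {
    collected.iter().copied().reduce(|a, b| Req {
        num_cpus: a.num_cpus.max(b.num_cpus),
        num_gpus: a.num_gpus.max(b.num_gpus),
    })
}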
diff --git a/src/daft-dsl/src/functions/sketch/mod.rs b/src/daft-dsl/src/functions/sketch/mod.rs
--- a/src/daft-dsl/src/functions/sketch/mod.rs
+++ b/src/daft-dsl/src/functions/sketch/mod.rs
@@ -20,7 +20,7 @@ impl Hash for HashableVecPercentiles {
     fn hash<H: Hasher>(&self, state: &mut H) {
         self.0
             .iter()
-            .for_each(|p| p.to_be_bytes().iter().for_each(|&b| state.write_u8(b)))
+            .for_each(|p| p.to_be_bytes().iter().for_each(|&b| state.write_u8(b)));
     }
 }
@@ -30,9 +30,8 @@ pub enum SketchExpr {
 impl SketchExpr {
     #[inline]
     pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        use SketchExpr::*;
         match self {
-            Percentile { .. } => &PercentileEvaluator {},
+            Self::Percentile { .. } => &PercentileEvaluator {},
         }
     }
 }
@@ -43,7 +42,7 @@ pub fn sketch_percentile(
             percentiles: HashableVecPercentiles(percentiles.to_vec()),
             force_list_output,
         }),
-        inputs: vec![input.clone()],
+        inputs: vec![input],
     }
     .into()
 }
diff --git a/src/daft-dsl/src/functions/struct_/mod.rs b/src/daft-dsl/src/functions/struct_/mod.rs
index c842c45c64..7d8d192d25 100644
--- a/src/daft-dsl/src/functions/struct_/mod.rs
+++ b/src/daft-dsl/src/functions/struct_/mod.rs
@@ -14,9 +14,8 @@ pub enum StructExpr {
 impl StructExpr {
     #[inline]
     pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        use StructExpr::*;
         match self {
-            Get(_) => &GetEvaluator {},
+            Self::Get(_) => &GetEvaluator {},
         }
     }
 }
diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs
index cb3a07aca1..7a795250ff 100644
--- a/src/daft-dsl/src/functions/utf8/mod.rs
+++ b/src/daft-dsl/src/functions/utf8/mod.rs
@@ -95,36 +95,35 @@ pub enum Utf8Expr {
 impl Utf8Expr {
     #[inline]
     pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
-        use Utf8Expr::*;
         match self {
-            EndsWith => &EndswithEvaluator {},
-            StartsWith => &StartswithEvaluator {},
-            Contains => &ContainsEvaluator {},
-            Split(_) => &SplitEvaluator {},
-            Match => &MatchEvaluator {},
-            Extract(_) => &ExtractEvaluator {},
-            ExtractAll(_) => &ExtractAllEvaluator {},
-            Replace(_) => &ReplaceEvaluator {},
-            Length => &LengthEvaluator {},
-            LengthBytes => &LengthBytesEvaluator {},
-            Lower => &LowerEvaluator {},
-            Upper => &UpperEvaluator {},
-            Lstrip => &LstripEvaluator {},
-            Rstrip => &RstripEvaluator {},
-            Reverse => &ReverseEvaluator {},
-            Capitalize => &CapitalizeEvaluator {},
-            Left => &LeftEvaluator {},
-            Right => &RightEvaluator {},
-            Find => &FindEvaluator {},
-            Rpad => &RpadEvaluator {},
-            Lpad => &LpadEvaluator {},
-            Repeat => &RepeatEvaluator {},
-            Like => &LikeEvaluator {},
-            Ilike => &IlikeEvaluator {},
-            Substr => &SubstrEvaluator {},
-            ToDate(_) => &ToDateEvaluator {},
-            ToDatetime(_, _) => &ToDatetimeEvaluator {},
-            Normalize(_) => &NormalizeEvaluator {},
+            Self::EndsWith => &EndswithEvaluator {},
+            Self::StartsWith => &StartswithEvaluator {},
+            Self::Contains => &ContainsEvaluator {},
+            Self::Split(_) => &SplitEvaluator {},
+            Self::Match => &MatchEvaluator {},
+            Self::Extract(_) => &ExtractEvaluator {},
+            Self::ExtractAll(_) => &ExtractAllEvaluator {},
+            Self::Replace(_) => &ReplaceEvaluator {},
+            Self::Length => &LengthEvaluator {},
+            Self::LengthBytes => &LengthBytesEvaluator {},
+            Self::Lower => &LowerEvaluator {},
+            Self::Upper => &UpperEvaluator {},
+            Self::Lstrip => &LstripEvaluator {},
+            Self::Rstrip => &RstripEvaluator {},
+            Self::Reverse => &ReverseEvaluator {},
+            Self::Capitalize => &CapitalizeEvaluator {},
+            Self::Left => &LeftEvaluator {},
+            Self::Right => &RightEvaluator {},
+            Self::Find => &FindEvaluator {},
+            Self::Rpad => &RpadEvaluator {},
+            Self::Lpad => &LpadEvaluator {},
+            Self::Repeat => &RepeatEvaluator {},
+            Self::Like => &LikeEvaluator {},
+            Self::Ilike => &IlikeEvaluator {},
+            Self::Substr => &SubstrEvaluator {},
+            Self::ToDate(_) => &ToDateEvaluator {},
+            Self::ToDatetime(_, _) => &ToDatetimeEvaluator {},
+            Self::Normalize(_) => &NormalizeEvaluator {},
         }
     }
 }
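// Editor's note: a tiny sketch of the `use_self` idiom this patch rolls out across every
// `match` above: inside an impl block, `Self::Variant` replaces both the repeated enum
// name and the old `use Enum::*;` glob imports, so renaming the type touches one line and
// no glob can shadow unrelated names.
enum Case {
    Lower,
    Upper,
}

impl Case {
    fn flipped(&self) -> Self {
        match self {
            Self::Lower => Self::Upper, // was: `Lower => Upper` after `use Case::*;`
            Self::Upper => Self::Lower,
        }
    }
}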
diff --git a/src/daft-dsl/src/join.rs b/src/daft-dsl/src/join/mod.rs
similarity index 79%
rename from src/daft-dsl/src/join.rs
rename to src/daft-dsl/src/join/mod.rs
index 2f1cf96cb2..1de29b995e 100644
--- a/src/daft-dsl/src/join.rs
+++ b/src/daft-dsl/src/join/mod.rs
@@ -1,3 +1,6 @@
+#[cfg(test)]
+mod tests;
+
 use std::sync::Arc;

 use common_error::{DaftError, DaftResult};
@@ -79,34 +82,3 @@ pub fn infer_join_schema(
         Ok(Schema::new(fields)?.into())
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::col;
-
-    #[test]
-    fn test_get_common_join_keys() {
-        let left_on: &[ExprRef] = &[
-            col("a"),
-            col("b_left"),
-            col("c").alias("c_new"),
-            col("d").alias("d_new"),
-            col("e").add(col("f")),
-        ];
-
-        let right_on: &[ExprRef] = &[
-            col("a"),
-            col("b_right"),
-            col("c"),
-            col("d").alias("d_new"),
-            col("e"),
-        ];
-
-        let common_join_keys = get_common_join_keys(left_on, right_on)
-            .map(|k| k.to_string())
-            .collect::<Vec<_>>();
-
-        assert_eq!(common_join_keys, vec!["a"]);
-    }
-}
diff --git a/src/daft-dsl/src/join/tests.rs b/src/daft-dsl/src/join/tests.rs
new file mode 100644
index 0000000000..52d58a76c0
--- /dev/null
+++ b/src/daft-dsl/src/join/tests.rs
@@ -0,0 +1,27 @@
+use super::*;
+use crate::col;
+
+#[test]
+fn test_get_common_join_keys() {
+    let left_on: &[ExprRef] = &[
+        col("a"),
+        col("b_left"),
+        col("c").alias("c_new"),
+        col("d").alias("d_new"),
+        col("e").add(col("f")),
+    ];
+
+    let right_on: &[ExprRef] = &[
+        col("a"),
+        col("b_right"),
+        col("c"),
+        col("d").alias("d_new"),
+        col("e"),
+    ];
+
+    let common_join_keys = get_common_join_keys(left_on, right_on)
+        .map(|k| k.to_string())
+        .collect::<Vec<_>>();
+
+    assert_eq!(common_join_keys, vec!["a"]);
+}
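// Editor's note: a hedged restatement of what the moved test above pins down:
// `get_common_join_keys` only yields keys that appear as the *same bare column* on both
// sides. Aliased (`c as c_new`, `d as d_new`) and computed (`e + f`) keys are excluded,
// which is why only "a" survives from the five key pairs in the test.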
diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs
index 754578eb6d..c3f5d68594 100644
--- a/src/daft-dsl/src/lib.rs
+++ b/src/daft-dsl/src/lib.rs
@@ -18,7 +18,7 @@ pub use expr::{
     binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
     ApproxPercentileParams, Expr, ExprRef, Operator, SketchType,
 };
-pub use lit::{lit, literals_to_series, null_lit, Literal, LiteralValue};
+pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
 #[cfg(feature = "python")]
 use pyo3::prelude::*;
 pub use resolve_expr::{
@@ -35,6 +35,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
     parent.add_function(wrap_pyfunction_bound!(python::date_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::time_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::timestamp_lit, parent)?)?;
+    parent.add_function(wrap_pyfunction_bound!(python::duration_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?;
     parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?;
diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs
index 5db0f05a3d..52ca09af18 100644
--- a/src/daft-dsl/src/lit.rs
+++ b/src/daft-dsl/src/lit.rs
@@ -10,10 +10,11 @@ use common_hashable_float_wrapper::FloatWrapper;
 use daft_core::{
     prelude::*,
     utils::display::{
-        display_date32, display_decimal128, display_series_literal, display_time64,
-        display_timestamp,
+        display_date32, display_decimal128, display_duration, display_series_literal,
+        display_time64, display_timestamp,
     },
 };
+use indexmap::IndexMap;
 use serde::{Deserialize, Serialize};

 #[cfg(feature = "python")]
@@ -59,6 +60,8 @@ pub enum LiteralValue {
     Date(i32),
     /// An [`i64`] representing a time in microseconds or nanoseconds since midnight.
     Time(i64, TimeUnit),
+    /// An [`i64`] representing a measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.)
+    Duration(i64, TimeUnit),
     /// A 64-bit floating point number.
     Float64(f64),
     /// An [`i128`] representing a decimal number with the provided precision and scale.
@@ -68,50 +71,60 @@ pub enum LiteralValue {
     /// Python object.
     #[cfg(feature = "python")]
     Python(PyObjectWrapper),
+
+    Struct(IndexMap<Field, LiteralValue>),
 }

 impl Eq for LiteralValue {}

 impl Hash for LiteralValue {
     fn hash<H: Hasher>(&self, state: &mut H) {
-        use LiteralValue::*;
-
         match self {
             // Stable hash for Null variant.
-            Null => 1.hash(state),
-            Boolean(bool) => bool.hash(state),
-            Utf8(s) => s.hash(state),
-            Binary(arr) => arr.hash(state),
-            Int32(n) => n.hash(state),
-            UInt32(n) => n.hash(state),
-            Int64(n) => n.hash(state),
-            UInt64(n) => n.hash(state),
-            Date(n) => n.hash(state),
-            Time(n, tu) => {
+            Self::Null => 1.hash(state),
+            Self::Boolean(bool) => bool.hash(state),
+            Self::Utf8(s) => s.hash(state),
+            Self::Binary(arr) => arr.hash(state),
+            Self::Int32(n) => n.hash(state),
+            Self::UInt32(n) => n.hash(state),
+            Self::Int64(n) => n.hash(state),
+            Self::UInt64(n) => n.hash(state),
+            Self::Date(n) => n.hash(state),
+            Self::Time(n, tu) => {
                 n.hash(state);
                 tu.hash(state);
             }
-            Timestamp(n, tu, tz) => {
+            Self::Timestamp(n, tu, tz) => {
                 n.hash(state);
                 tu.hash(state);
                 tz.hash(state);
             }
+            Self::Duration(n, tu) => {
+                n.hash(state);
+                tu.hash(state);
+            }
             // Wrap float64 in hashable newtype.
-            Float64(n) => FloatWrapper(*n).hash(state),
-            Decimal(n, precision, scale) => {
+            Self::Float64(n) => FloatWrapper(*n).hash(state),
+            Self::Decimal(n, precision, scale) => {
                 n.hash(state);
                 precision.hash(state);
                 scale.hash(state);
             }
-            Series(series) => {
+            Self::Series(series) => {
                 let hash_result = series.hash(None);
                 match hash_result {
                     Ok(hash) => hash.into_iter().for_each(|i| i.hash(state)),
-                    Err(_) => panic!("Cannot hash series"),
+                    Err(..) => panic!("Cannot hash series"),
                 }
             }
             #[cfg(feature = "python")]
-            Python(py_obj) => py_obj.hash(state),
+            Self::Python(py_obj) => py_obj.hash(state),
+            Self::Struct(entries) => {
+                entries.iter().for_each(|(v, f)| {
+                    v.hash(state);
+                    f.hash(state);
+                });
+            }
         }
     }
 }
@@ -119,122 +132,148 @@ impl Hash for LiteralValue {
 impl Display for LiteralValue {
     // `f` is a buffer, and this method must write the formatted string into it
     fn fmt(&self, f: &mut Formatter) -> Result {
-        use LiteralValue::*;
         match self {
-            Null => write!(f, "Null"),
-            Boolean(val) => write!(f, "{val}"),
-            Utf8(val) => write!(f, "\"{val}\""),
-            Binary(val) => write!(f, "Binary[{}]", val.len()),
-            Int32(val) => write!(f, "{val}"),
-            UInt32(val) => write!(f, "{val}"),
-            Int64(val) => write!(f, "{val}"),
-            UInt64(val) => write!(f, "{val}"),
-            Date(val) => write!(f, "{}", display_date32(*val)),
-            Time(val, tu) => write!(f, "{}", display_time64(*val, tu)),
-            Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)),
-            Float64(val) => write!(f, "{val:.1}"),
-            Decimal(val, precision, scale) => {
+            Self::Null => write!(f, "Null"),
+            Self::Boolean(val) => write!(f, "{val}"),
+            Self::Utf8(val) => write!(f, "\"{val}\""),
+            Self::Binary(val) => write!(f, "Binary[{}]", val.len()),
+            Self::Int32(val) => write!(f, "{val}"),
+            Self::UInt32(val) => write!(f, "{val}"),
+            Self::Int64(val) => write!(f, "{val}"),
+            Self::UInt64(val) => write!(f, "{val}"),
+            Self::Date(val) => write!(f, "{}", display_date32(*val)),
+            Self::Time(val, tu) => write!(f, "{}", display_time64(*val, tu)),
+            Self::Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)),
+            Self::Duration(val, tu) => write!(f, "{}", display_duration(*val, tu)),
+            Self::Float64(val) => write!(f, "{val:.1}"),
+            Self::Decimal(val, precision, scale) => {
                 write!(f, "{}", display_decimal128(*val, *precision, *scale))
             }
-            Series(series) => write!(f, "{}", display_series_literal(series)),
+            Self::Series(series) => write!(f, "{}", display_series_literal(series)),
             #[cfg(feature = "python")]
-            Python(pyobj) => write!(f, "PyObject({})", {
+            Self::Python(pyobj) => write!(f, "PyObject({})", {
                 use pyo3::prelude::*;
                 Python::with_gil(|py| pyobj.0.call_method0(py, pyo3::intern!(py, "__str__")))
                     .unwrap()
             }),
+            Self::Struct(entries) => {
+                write!(f, "Struct(")?;
+                for (i, (field, v)) in entries.iter().enumerate() {
+                    if i > 0 {
+                        write!(f, ", ")?;
+                    }
+                    write!(f, "{}: {}", field.name, v)?;
+                }
+                write!(f, ")")
+            }
         }
     }
 }

 impl LiteralValue {
     pub fn get_type(&self) -> DataType {
-        use LiteralValue::*;
         match self {
-            Null => DataType::Null,
-            Boolean(_) => DataType::Boolean,
-            Utf8(_) => DataType::Utf8,
-            Binary(_) => DataType::Binary,
-            Int32(_) => DataType::Int32,
-            UInt32(_) => DataType::UInt32,
-            Int64(_) => DataType::Int64,
-            UInt64(_) => DataType::UInt64,
-            Date(_) => DataType::Date,
-            Time(_, tu) => DataType::Time(*tu),
-            Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()),
-            Float64(_) => DataType::Float64,
-            Decimal(_, precision, scale) => {
+            Self::Null => DataType::Null,
+            Self::Boolean(_) => DataType::Boolean,
+            Self::Utf8(_) => DataType::Utf8,
+            Self::Binary(_) => DataType::Binary,
+            Self::Int32(_) => DataType::Int32,
+            Self::UInt32(_) => DataType::UInt32,
+            Self::Int64(_) => DataType::Int64,
+            Self::UInt64(_) => DataType::UInt64,
+            Self::Date(_) => DataType::Date,
+            Self::Time(_, tu) => DataType::Time(*tu),
+            Self::Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()),
+            Self::Duration(_, tu) => DataType::Duration(*tu),
+            Self::Float64(_) => DataType::Float64,
+            Self::Decimal(_, precision, scale) => {
                 DataType::Decimal128(*precision as usize, *scale as usize)
             }
-            Series(series) => series.data_type().clone(),
+            Self::Series(series) => series.data_type().clone(),
             #[cfg(feature = "python")]
-            Python(_) => DataType::Python,
+            Self::Python(_) => DataType::Python,
+            Self::Struct(entries) => DataType::Struct(entries.keys().cloned().collect()),
         }
     }

     pub fn to_series(&self) -> Series {
-        use LiteralValue::*;
-        let result = match self {
-            Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(),
-            Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(),
-            Utf8(val) => Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series(),
-            Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(),
-            Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(),
-            UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(),
-            Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(),
-            UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(),
-            Date(val) => {
+        match self {
+            Self::Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(),
+            Self::Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(),
+            Self::Utf8(val) => {
+                Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series()
+            }
+            Self::Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(),
+            Self::Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(),
+            Self::UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(),
+            Self::Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(),
+            Self::UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(),
+            Self::Date(val) => {
                 let physical = Int32Array::from(("literal", [*val].as_slice()));
                 DateArray::new(Field::new("literal", self.get_type()), physical).into_series()
             }
-            Time(val, ..) => {
+            Self::Time(val, ..) => {
                 let physical = Int64Array::from(("literal", [*val].as_slice()));
                 TimeArray::new(Field::new("literal", self.get_type()), physical).into_series()
             }
-            Timestamp(val, ..) => {
+            Self::Timestamp(val, ..) => {
                 let physical = Int64Array::from(("literal", [*val].as_slice()));
                 TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series()
             }
-            Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(),
-            Decimal(val, ..) => {
+            Self::Duration(val, ..) => {
+                let physical = Int64Array::from(("literal", [*val].as_slice()));
+                DurationArray::new(Field::new("literal", self.get_type()), physical).into_series()
+            }
+            Self::Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(),
+            Self::Decimal(val, ..) => {
                 let physical = Int128Array::from(("literal", [*val].as_slice()));
                 Decimal128Array::new(Field::new("literal", self.get_type()), physical)
                     .into_series()
             }
-            Series(series) => series.clone().rename("literal"),
+            Self::Series(series) => series.clone().rename("literal"),
             #[cfg(feature = "python")]
-            Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(),
-        };
-        result
+            Self::Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(),
+            Self::Struct(entries) => {
+                let struct_dtype = DataType::Struct(entries.keys().cloned().collect());
+                let struct_field = Field::new("literal", struct_dtype);
+
+                let values = entries.values().map(|v| v.to_series()).collect();
+                StructArray::new(struct_field, values, None).into_series()
+            }
+        }
     }

     pub fn display_sql<W: Write>(&self, buffer: &mut W) -> io::Result<()> {
-        use LiteralValue::*;
         let display_sql_err = Err(io::Error::new(
             io::ErrorKind::Other,
             "Unsupported literal for SQL translation",
         ));
         match self {
-            Null => write!(buffer, "NULL"),
-            Boolean(v) => write!(buffer, "{}", v),
-            Int32(val) => write!(buffer, "{}", val),
-            UInt32(val) => write!(buffer, "{}", val),
-            Int64(val) => write!(buffer, "{}", val),
-            UInt64(val) => write!(buffer, "{}", val),
-            Float64(val) => write!(buffer, "{}", val),
-            Utf8(val) => write!(buffer, "'{}'", val),
-            Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)),
+            Self::Null => write!(buffer, "NULL"),
+            Self::Boolean(v) => write!(buffer, "{}", v),
+            Self::Int32(val) => write!(buffer, "{}", val),
+            Self::UInt32(val) => write!(buffer, "{}", val),
+            Self::Int64(val) => write!(buffer, "{}", val),
+            Self::UInt64(val) => write!(buffer, "{}", val),
+            Self::Float64(val) => write!(buffer, "{}", val),
+            Self::Utf8(val) => write!(buffer, "'{}'", val),
+            Self::Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)),
             // The `display_timestamp` function formats a timestamp in the ISO 8601 format: "YYYY-MM-DDTHH:MM:SS.fffff".
             // ANSI SQL standard uses a space instead of 'T'. Some databases do not support 'T', hence it's replaced with a space.
             // Reference: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.html
-            Timestamp(val, tu, tz) => write!(
+            Self::Timestamp(val, tu, tz) => write!(
                 buffer,
                 "TIMESTAMP '{}'",
                 display_timestamp(*val, tu, tz).replace('T', " ")
             ),
             // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns.
-            Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err,
+            Self::Decimal(..)
+            | Self::Series(..)
+            | Self::Time(..)
+            | Self::Binary(..)
+            | Self::Duration(..) => display_sql_err,
             #[cfg(feature = "python")]
-            Python(..) => display_sql_err,
+            Self::Python(..) => display_sql_err,
+            Self::Struct(..) => display_sql_err,
         }
     }
@@ -304,49 +343,64 @@ impl LiteralValue {
     }
 }
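// Editor's note: a hypothetical downstream implementor (`Percent` is invented) showing the
// reshaped `Literal` trait below: implementors now supply only `literal_value`, `lit`
// comes for free as a provided method, and the new blanket `Option<T>` impl maps `None`
// to `LiteralValue::Null`.
struct Percent(f64);

impl daft_dsl::Literal for Percent {
    fn literal_value(self) -> daft_dsl::LiteralValue {
        daft_dsl::LiteralValue::Float64(self.0 / 100.0)
    }
}
// `Percent(50.0).lit()` and `Some(Percent(50.0)).lit()` build the same literal expression,
// while `None::<Percent>.lit()` becomes a null literal.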
-pub trait Literal {
+pub trait Literal: Sized {
     /// [Literal](Expr::Literal) expression.
-    fn lit(self) -> ExprRef;
+    fn lit(self) -> ExprRef {
+        Expr::Literal(self.literal_value()).into()
+    }
+    fn literal_value(self) -> LiteralValue;
 }

 impl Literal for String {
-    fn lit(self) -> ExprRef {
-        Expr::Literal(LiteralValue::Utf8(self)).into()
+    fn literal_value(self) -> LiteralValue {
+        LiteralValue::Utf8(self)
     }
 }

 impl<'a> Literal for &'a str {
-    fn lit(self) -> ExprRef {
-        Expr::Literal(LiteralValue::Utf8(self.to_owned())).into()
+    fn literal_value(self) -> LiteralValue {
+        LiteralValue::Utf8(self.to_owned())
     }
 }

 macro_rules! make_literal {
     ($TYPE:ty, $SCALAR:ident) => {
         impl Literal for $TYPE {
-            fn lit(self) -> ExprRef {
-                Expr::Literal(LiteralValue::$SCALAR(self)).into()
+            fn literal_value(self) -> LiteralValue {
+                LiteralValue::$SCALAR(self)
             }
         }
     };
 }

 impl<'a> Literal for &'a [u8] {
-    fn lit(self) -> ExprRef {
-        Expr::Literal(LiteralValue::Binary(self.to_vec())).into()
+    fn literal_value(self) -> LiteralValue {
+        LiteralValue::Binary(self.to_vec())
     }
 }

 impl Literal for Series {
-    fn lit(self) -> ExprRef {
-        Expr::Literal(LiteralValue::Series(self)).into()
+    fn literal_value(self) -> LiteralValue {
+        LiteralValue::Series(self)
     }
 }

 #[cfg(feature = "python")]
 impl Literal for pyo3::PyObject {
-    fn lit(self) -> ExprRef {
-        Expr::Literal(LiteralValue::Python(PyObjectWrapper(self))).into()
+    fn literal_value(self) -> LiteralValue {
+        LiteralValue::Python(PyObjectWrapper(self))
+    }
+}
+
+impl<T> Literal for Option<T>
+where
+    T: Literal,
+{
+    fn literal_value(self) -> LiteralValue {
+        match self {
+            Some(val) => val.literal_value(),
+            None => LiteralValue::Null,
+        }
     }
 }

@@ -361,6 +415,10 @@ pub fn lit<L: Literal>(t: L) -> ExprRef {
     t.lit()
 }

+pub fn literal_value<L: Literal>(t: L) -> LiteralValue {
+    t.literal_value()
+}
+
 pub fn null_lit() -> ExprRef {
     Arc::new(Expr::Literal(LiteralValue::Null))
 }
diff --git a/src/daft-dsl/src/pyobj_serde.rs b/src/daft-dsl/src/pyobj_serde.rs
index e5ec52d70b..8abdfc76bd 100644
--- a/src/daft-dsl/src/pyobj_serde.rs
+++ b/src/daft-dsl/src/pyobj_serde.rs
@@ -35,7 +35,7 @@ impl Hash for PyObjectWrapper {
             Err(_) => {
                 let hasher = HashWriter { state };
                 bincode::serialize_into(hasher, self)
-                    .expect("Pickling error occurred when computing hash of Pyobject")
+                    .expect("Pickling error occurred when computing hash of Pyobject");
             }
         }
     }
diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs
index a4a54e74c9..e0c6dc1700 100644
--- a/src/daft-dsl/src/python.rs
+++ b/src/daft-dsl/src/python.rs
@@ -46,6 +46,12 @@ pub fn timestamp_lit(val: i64, tu: PyTimeUnit, tz: Option<String>) -> PyResult<PyExpr> {
     Ok(expr.into())
 }

+#[pyfunction]
+pub fn duration_lit(val: i64, tu: PyTimeUnit) -> PyResult<PyExpr> {
+    let expr = Expr::Literal(LiteralValue::Duration(val, tu.timeunit));
+    Ok(expr.into())
+}
+
 fn decimal_from_digits(digits: Vec<u8>, exp: i32) -> Option<(i128, usize)> {
     const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1;
     let mut v = 0_i128;
@@ -222,7 +228,9 @@ pub fn stateful_udf(

 /// Extracts the `class PartialStatefulUDF` Python objects that are in the specified expression tree
 #[pyfunction]
-pub fn extract_partial_stateful_udf_py(expr: PyExpr) -> HashMap<String, Py<PyAny>> {
+pub fn extract_partial_stateful_udf_py(
+    expr: PyExpr,
+) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
     use crate::functions::python::extract_partial_stateful_udf_py;
     extract_partial_stateful_udf_py(expr.expr)
 }
@@ -299,7 +307,7 @@ impl PyExpr {
             ApproxPercentileInput::Many(p) => (p, true),
         };

-        for &p in percentiles.iter() {
+        for &p in &percentiles {
             if !(0. ..=1.).contains(&p) {
                 return Err(PyValueError::new_err(format!(
                     "Provided percentile must be between 0 and 1: {}",
@@ -319,6 +327,10 @@ impl PyExpr {
         Ok(self.expr.clone().mean().into())
     }

+    pub fn stddev(&self) -> PyResult<PyExpr> {
+        Ok(self.expr.clone().stddev().into())
+    }
+
     pub fn min(&self) -> PyResult<PyExpr> {
         Ok(self.expr.clone().min().into())
     }
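// Editor's note: a hedged sketch (assumed import paths) of what the new `duration_lit`
// pyfunction above constructs on the Rust side: a `LiteralValue::Duration(i64, TimeUnit)`
// whose `get_type()` is `DataType::Duration(tu)` and which materializes as a
// `DurationArray`.
use daft_core::datatypes::TimeUnit;
use daft_dsl::{Expr, ExprRef, LiteralValue};

fn five_seconds() -> ExprRef {
    Expr::Literal(LiteralValue::Duration(5, TimeUnit::Seconds)).into()
}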
AggExpr::ApproxCountDistinct(Expr::Alias(e, name.clone()).into()) } - MergeSketch(e, sketch_type) => { - MergeSketch(Alias(e, name.clone()).into(), sketch_type) + AggExpr::ApproxSketch(e, sketch_type) => { + AggExpr::ApproxSketch(Expr::Alias(e, name.clone()).into(), sketch_type) } - Mean(e) => Mean(Alias(e, name.clone()).into()), - Min(e) => Min(Alias(e, name.clone()).into()), - Max(e) => Max(Alias(e, name.clone()).into()), - AnyValue(e, ignore_nulls) => AnyValue(Alias(e, name.clone()).into(), ignore_nulls), - List(e) => List(Alias(e, name.clone()).into()), - Concat(e) => Concat(Alias(e, name.clone()).into()), - MapGroups { func, inputs } => MapGroups { + AggExpr::MergeSketch(e, sketch_type) => { + AggExpr::MergeSketch(Expr::Alias(e, name.clone()).into(), sketch_type) + } + AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), + AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), + AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), + AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), + AggExpr::AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(Expr::Alias(e, name.clone()).into(), ignore_nulls) + } + AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()), + AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()), + AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups { func, inputs: inputs .into_iter() @@ -410,148 +415,3 @@ pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { - let struct_expr_map = calculate_struct_expr_map(schema); - transform_struct_gets(expr, &struct_expr_map) - } - - #[test] - fn test_substitute_expr_getter_sugar() -> DaftResult<()> { - use crate::functions::struct_::get as struct_get; - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); - - assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); - assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); - assert!(matches!( - substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), - DaftError::ValueError(..) 
- )); - - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new("b", DataType::Int64)]), - )])?); - - assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - struct_get(col("a"), "b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, - struct_get(col("a"), "b").alias("c") - ); - - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new( - "b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - )]), - )])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - struct_get(col("a"), "b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(struct_get(col("a"), "b"), "c") - ); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new( - "b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - )]), - ), - Field::new("a.b", DataType::Int64), - ])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - col("a.b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(struct_get(col("a"), "b"), "c") - ); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), - ), - Field::new( - "a.b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - ), - ])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(col("a.b"), "c") - ); - - Ok(()) - } - - #[test] - fn test_find_wildcards() -> DaftResult<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new("b.*", DataType::Int64)]), - ), - Field::new("c.*", DataType::Int64), - ])?; - let struct_expr_map = calculate_struct_expr_map(&schema); - - let wildcards = find_wildcards(col("test"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "*"); - - let wildcards = find_wildcards(col("t*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); - - let wildcards = find_wildcards(col("a.*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.*"); - - let wildcards = find_wildcards(col("c.*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("a.b.*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("a.b*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.b*"); - - // nested expression - let wildcards = find_wildcards(col("t*").add(col("a.*")), &struct_expr_map); - assert!(wildcards.len() == 2); - assert!(wildcards.iter().any(|s| s.as_ref() == "t*")); - assert!(wildcards.iter().any(|s| s.as_ref() == "a.*")); - - let wildcards = find_wildcards(col("t*").add(col("a")), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); - - // schema containing * - let schema = Schema::new(vec![Field::new("*", DataType::Int64)])?; - let struct_expr_map = calculate_struct_expr_map(&schema); - - let wildcards = find_wildcards(col("*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - Ok(()) - } -} diff --git a/src/daft-dsl/src/resolve_expr/tests.rs 
b/src/daft-dsl/src/resolve_expr/tests.rs new file mode 100644 index 0000000000..dcb3147207 --- /dev/null +++ b/src/daft-dsl/src/resolve_expr/tests.rs @@ -0,0 +1,141 @@ +use super::*; + +fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { + let struct_expr_map = calculate_struct_expr_map(schema); + transform_struct_gets(expr, &struct_expr_map) +} + +#[test] +fn test_substitute_expr_getter_sugar() -> DaftResult<()> { + use crate::functions::struct_::get as struct_get; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); + assert!(matches!( + substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), + DaftError::ValueError(..) + )); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new("b", DataType::Int64)]), + )])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, + struct_get(col("a"), "b").alias("c") + ); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + )])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + ), + Field::new("a.b", DataType::Int64), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + col("a.b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), + ), + Field::new( + "a.b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + ), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(col("a.b"), "c") + ); + + Ok(()) +} + +#[test] +fn test_find_wildcards() -> DaftResult<()> { + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.*", DataType::Int64)]), + ), + Field::new("c.*", DataType::Int64), + ])?; + let struct_expr_map = calculate_struct_expr_map(&schema); + + let wildcards = find_wildcards(col("test"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "*"); + + let wildcards = find_wildcards(col("t*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); + + let wildcards = find_wildcards(col("a.*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.*"); + + let wildcards = find_wildcards(col("c.*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("a.b.*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("a.b*"), &struct_expr_map); + 
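// The assertions in this test pin down a subtle rule: a column reference only
// counts as a wildcard if it contains `*` and does not name a reachable path
// verbatim, so a schema may legitimately contain columns named `*` or `c.*`.
// A std-only sketch of that rule; the path set and `is_wildcard` helper are
// illustrative stand-ins for `calculate_struct_expr_map`/`find_wildcards`:

use std::collections::HashSet;

fn is_wildcard(reference: &str, reachable_paths: &HashSet<&str>) -> bool {
    reference.contains('*') && !reachable_paths.contains(reference)
}

fn main() {
    // Paths reachable in the test schema: column `c.*`, plus struct `a`
    // holding a field literally named `b.*`.
    let paths: HashSet<&str> = ["a", "a.b.*", "c.*"].into_iter().collect();
    assert!(!is_wildcard("test", &paths)); // no `*` at all
    assert!(is_wildcard("t*", &paths)); // pattern, matches nothing verbatim
    assert!(is_wildcard("a.*", &paths)); // struct `a` has no field named `*`
    assert!(!is_wildcard("c.*", &paths)); // a real column that contains `*`
    assert!(!is_wildcard("a.b.*", &paths)); // a real nested path
    assert!(is_wildcard("a.b*", &paths)); // prefix pattern, still a wildcard
}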
assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.b*"); + + // nested expression + let wildcards = find_wildcards(col("t*").add(col("a.*")), &struct_expr_map); + assert!(wildcards.len() == 2); + assert!(wildcards.iter().any(|s| s.as_ref() == "t*")); + assert!(wildcards.iter().any(|s| s.as_ref() == "a.*")); + + let wildcards = find_wildcards(col("t*").add(col("a")), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); + + // schema containing * + let schema = Schema::new(vec![Field::new("*", DataType::Int64)])?; + let struct_expr_map = calculate_struct_expr_map(&schema); + + let wildcards = find_wildcards(col("*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + Ok(()) +} diff --git a/src/daft-functions-json/src/lib.rs b/src/daft-functions-json/src/lib.rs index f04a99fd82..6c57b15039 100644 --- a/src/daft-functions-json/src/lib.rs +++ b/src/daft-functions-json/src/lib.rs @@ -32,7 +32,7 @@ fn compile_filter(query: &str) -> DaftResult { if !errs.is_empty() { return Err(DaftError::ValueError(format!( "Error parsing json query ({query}): {}", - errs.iter().map(|e| e.to_string()).join(", ") + errs.iter().map(std::string::ToString::to_string).join(", ") ))); } @@ -92,8 +92,7 @@ pub fn json_query_series(s: &Series, query: &str) -> DaftResult { json_query_impl(arr, query).map(daft_core::series::IntoSeries::into_series) } dt => Err(DaftError::TypeError(format!( - "json query not implemented for {}", - dt + "json query not implemented for {dt}" ))), } } @@ -108,6 +107,7 @@ pub fn json_query_series(s: &Series, query: &str) -> DaftResult { /// # Returns /// /// A `DaftResult` containing the resulting UTF-8 array after applying the query. +#[must_use] pub fn json_query(input: ExprRef, query: &str) -> ExprRef { ScalarFunction::new( JsonQuery { @@ -153,7 +153,7 @@ mod tests { .into_iter(), ); - let query = r#".foo.bar"#; + let query = r".foo.bar"; let result = json_query_impl(&data, query)?; assert_eq!(result.len(), 3); assert_eq!(result.as_arrow().value(0), "1"); diff --git a/src/daft-functions/src/count_matches.rs b/src/daft-functions/src/count_matches.rs index 89df9274a9..9c56974358 100644 --- a/src/daft-functions/src/count_matches.rs +++ b/src/daft-functions/src/count_matches.rs @@ -7,9 +7,9 @@ use daft_dsl::{ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -struct CountMatchesFunction { - pub(super) whole_words: bool, - pub(super) case_sensitive: bool, +pub struct CountMatchesFunction { + pub whole_words: bool, + pub case_sensitive: bool, } #[typetag::serde] @@ -53,6 +53,7 @@ impl ScalarUDF for CountMatchesFunction { } } +#[must_use] pub fn utf8_count_matches( input: ExprRef, patterns: ExprRef, diff --git a/src/daft-functions/src/distance/cosine.rs b/src/daft-functions/src/distance/cosine.rs index 170587c1bb..11b3d1eef2 100644 --- a/src/daft-functions/src/distance/cosine.rs +++ b/src/daft-functions/src/distance/cosine.rs @@ -26,10 +26,10 @@ impl SpatialSimilarity for f32 { let xy = a .iter() .zip(b) - .map(|(a, b)| *a as f64 * *b as f64) + .map(|(a, b)| f64::from(*a) * f64::from(*b)) .sum::(); - let x_sq = a.iter().map(|x| (*x as f64).powi(2)).sum::().sqrt(); - let y_sq = b.iter().map(|x| (*x as f64).powi(2)).sum::().sqrt(); + let x_sq = a.iter().map(|x| f64::from(*x).powi(2)).sum::().sqrt(); + let y_sq = b.iter().map(|x| f64::from(*x).powi(2)).sum::().sqrt(); Some(1.0 - xy / (x_sq * y_sq)) } } @@ -39,10 +39,10 @@ impl SpatialSimilarity for i8 { let xy = a .iter() 
.zip(b) - .map(|(a, b)| *a as f64 * *b as f64) + .map(|(a, b)| f64::from(*a) * f64::from(*b)) .sum::(); - let x_sq = a.iter().map(|x| (*x as f64).powi(2)).sum::().sqrt(); - let y_sq = b.iter().map(|x| (*x as f64).powi(2)).sum::().sqrt(); + let x_sq = a.iter().map(|x| f64::from(*x).powi(2)).sum::().sqrt(); + let y_sq = b.iter().map(|x| f64::from(*x).powi(2)).sum::().sqrt(); Some(1.0 - xy / (x_sq * y_sq)) } } @@ -140,8 +140,7 @@ impl ScalarUDF for CosineDistanceFunction { { if source_size != query_size { return Err(DaftError::ValueError(format!( - "Expected source and query to have the same size, instead got {} and {}", - source_size, query_size + "Expected source and query to have the same size, instead got {source_size} and {query_size}" ))); } } else { @@ -165,6 +164,7 @@ impl ScalarUDF for CosineDistanceFunction { } } +#[must_use] pub fn cosine_distance(a: ExprRef, b: ExprRef) -> ExprRef { ScalarFunction::new(CosineDistanceFunction {}, vec![a, b]).into() } diff --git a/src/daft-functions/src/float/fill_nan.rs b/src/daft-functions/src/float/fill_nan.rs index e79dd0a936..b2519c567d 100644 --- a/src/daft-functions/src/float/fill_nan.rs +++ b/src/daft-functions/src/float/fill_nan.rs @@ -53,6 +53,7 @@ impl ScalarUDF for FillNan { } } +#[must_use] pub fn fill_nan(input: ExprRef, fill_value: ExprRef) -> ExprRef { ScalarFunction::new(FillNan {}, vec![input, fill_value]).into() } diff --git a/src/daft-functions/src/float/is_inf.rs b/src/daft-functions/src/float/is_inf.rs index a46e221255..ebb00140d9 100644 --- a/src/daft-functions/src/float/is_inf.rs +++ b/src/daft-functions/src/float/is_inf.rs @@ -53,6 +53,7 @@ impl ScalarUDF for IsInf { } } +#[must_use] pub fn is_inf(input: ExprRef) -> ExprRef { ScalarFunction::new(IsInf {}, vec![input]).into() } diff --git a/src/daft-functions/src/float/is_nan.rs b/src/daft-functions/src/float/is_nan.rs index 365c09b80c..f75c605694 100644 --- a/src/daft-functions/src/float/is_nan.rs +++ b/src/daft-functions/src/float/is_nan.rs @@ -53,6 +53,7 @@ impl ScalarUDF for IsNan { } } +#[must_use] pub fn is_nan(input: ExprRef) -> ExprRef { ScalarFunction::new(IsNan {}, vec![input]).into() } diff --git a/src/daft-functions/src/float/not_nan.rs b/src/daft-functions/src/float/not_nan.rs index 87bca04011..396a6b5217 100644 --- a/src/daft-functions/src/float/not_nan.rs +++ b/src/daft-functions/src/float/not_nan.rs @@ -53,6 +53,7 @@ impl ScalarUDF for NotNan { } } +#[must_use] pub fn not_nan(input: ExprRef) -> ExprRef { ScalarFunction::new(NotNan {}, vec![input]).into() } diff --git a/src/daft-functions/src/hash.rs b/src/daft-functions/src/hash.rs index d21f49768a..f7ab7a7a30 100644 --- a/src/daft-functions/src/hash.rs +++ b/src/daft-functions/src/hash.rs @@ -21,7 +21,7 @@ impl ScalarUDF for HashFunction { fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [input] => input.hash(None).map(|s| s.into_series()), + [input] => input.hash(None).map(|arr| arr.into_series()), [input, seed] => { match seed.len() { 1 => { @@ -33,13 +33,17 @@ impl ScalarUDF for HashFunction { "seed", std::iter::repeat(Some(seed)).take(input.len()), ); - input.hash(Some(&seed)).map(|s| s.into_series()) + input + .hash(Some(&seed)) + .map(daft_core::series::IntoSeries::into_series) } _ if seed.len() == input.len() => { let seed = seed.cast(&DataType::UInt64)?; let seed = seed.u64().unwrap(); - input.hash(Some(seed)).map(|s| s.into_series()) + input + .hash(Some(seed)) + .map(daft_core::series::IntoSeries::into_series) } _ => Err(DaftError::ValueError( "Seed must be a single value or 
the same length as the input".to_string(), @@ -64,6 +68,7 @@ impl ScalarUDF for HashFunction { } } +#[must_use] pub fn hash(input: ExprRef, seed: Option) -> ExprRef { let inputs = match seed { Some(seed) => vec![input, seed], @@ -81,6 +86,6 @@ pub mod python { #[pyfunction] pub fn hash(expr: PyExpr, seed: Option) -> PyResult { use super::hash; - Ok(hash(expr.into(), seed.map(|s| s.into())).into()) + Ok(hash(expr.into(), seed.map(std::convert::Into::into)).into()) } } diff --git a/src/daft-functions/src/image/crop.rs b/src/daft-functions/src/image/crop.rs index ee464c2ca9..e69485daa4 100644 --- a/src/daft-functions/src/image/crop.rs +++ b/src/daft-functions/src/image/crop.rs @@ -42,8 +42,7 @@ impl ScalarUDF for ImageCrop { dtype => { return Err(DaftError::TypeError( format!( - "bbox list field must be List with numeric child type or FixedSizeList with size 4, got {}", - dtype + "bbox list field must be List with numeric child type or FixedSizeList with size 4, got {dtype}" ) )); } @@ -56,8 +55,7 @@ impl ScalarUDF for ImageCrop { Ok(Field::new(input_field.name, DataType::Image(Some(*mode)))) } _ => Err(DaftError::TypeError(format!( - "Image crop can only crop ImageArrays and FixedShapeImage, got {}", - input_field + "Image crop can only crop ImageArrays and FixedShapeImage, got {input_field}" ))), } } @@ -79,6 +77,7 @@ impl ScalarUDF for ImageCrop { } } +#[must_use] pub fn crop(input: ExprRef, bbox: ExprRef) -> ExprRef { ScalarFunction::new(ImageCrop {}, vec![input, bbox]).into() } diff --git a/src/daft-functions/src/image/decode.rs b/src/daft-functions/src/image/decode.rs index 99aabbee8a..1a81681058 100644 --- a/src/daft-functions/src/image/decode.rs +++ b/src/daft-functions/src/image/decode.rs @@ -47,8 +47,7 @@ impl ScalarUDF for ImageDecode { let field = input.to_field(schema)?; if !matches!(field.dtype, DataType::Binary) { return Err(DaftError::TypeError(format!( - "ImageDecode can only decode BinaryArrays, got {}", - field + "ImageDecode can only decode BinaryArrays, got {field}" ))); } Ok(Field::new(field.name, DataType::Image(self.mode))) @@ -72,6 +71,7 @@ impl ScalarUDF for ImageDecode { } } +#[must_use] pub fn decode(input: ExprRef, args: Option) -> ExprRef { ScalarFunction::new(args.unwrap_or_default(), vec![input]).into() } diff --git a/src/daft-functions/src/image/encode.rs b/src/daft-functions/src/image/encode.rs index f1a5bfaea4..110a2cbb08 100644 --- a/src/daft-functions/src/image/encode.rs +++ b/src/daft-functions/src/image/encode.rs @@ -34,8 +34,7 @@ impl ScalarUDF for ImageEncode { Ok(Field::new(field.name, DataType::Binary)) } _ => Err(DaftError::TypeError(format!( - "ImageEncode can only encode ImageArrays and FixedShapeImageArrays, got {}", - field + "ImageEncode can only encode ImageArrays and FixedShapeImageArrays, got {field}" ))), } } @@ -57,6 +56,7 @@ impl ScalarUDF for ImageEncode { } } +#[must_use] pub fn encode(input: ExprRef, image_encode: ImageEncode) -> ExprRef { ScalarFunction::new(image_encode, vec![input]).into() } diff --git a/src/daft-functions/src/image/resize.rs b/src/daft-functions/src/image/resize.rs index cac9fd7cf1..ea0468f31f 100644 --- a/src/daft-functions/src/image/resize.rs +++ b/src/daft-functions/src/image/resize.rs @@ -36,8 +36,7 @@ impl ScalarUDF for ImageResize { }, DataType::FixedShapeImage(..) 
=> Ok(field.clone()), _ => Err(DaftError::TypeError(format!( - "ImageResize can only resize ImageArrays and FixedShapeImageArrays, got {}", - field + "ImageResize can only resize ImageArrays and FixedShapeImageArrays, got {field}" ))), } } @@ -59,6 +58,7 @@ impl ScalarUDF for ImageResize { } } +#[must_use] pub fn resize(input: ExprRef, w: u32, h: u32) -> ExprRef { ScalarFunction::new( ImageResize { diff --git a/src/daft-functions/src/image/to_mode.rs b/src/daft-functions/src/image/to_mode.rs index 5d46a376dd..8609840f33 100644 --- a/src/daft-functions/src/image/to_mode.rs +++ b/src/daft-functions/src/image/to_mode.rs @@ -32,8 +32,7 @@ impl ScalarUDF for ImageToMode { } _ => { return Err(DaftError::TypeError(format!( - "ToMode can only operate on ImageArrays and FixedShapeImageArrays, got {}", - field + "ToMode can only operate on ImageArrays and FixedShapeImageArrays, got {field}" ))) } }; @@ -57,6 +56,7 @@ impl ScalarUDF for ImageToMode { } } +#[must_use] pub fn image_to_mode(expr: ExprRef, mode: ImageMode) -> ExprRef { ScalarFunction::new(ImageToMode { mode }, vec![expr]).into() } diff --git a/src/daft-functions/src/list/chunk.rs b/src/daft-functions/src/list/chunk.rs index 39743e80b9..1891a42945 100644 --- a/src/daft-functions/src/list/chunk.rs +++ b/src/daft-functions/src/list/chunk.rs @@ -51,6 +51,7 @@ impl ScalarUDF for ListChunk { } } +#[must_use] pub fn list_chunk(expr: ExprRef, size: usize) -> ExprRef { ScalarFunction::new(ListChunk { size }, vec![expr]).into() } diff --git a/src/daft-functions/src/list/count.rs b/src/daft-functions/src/list/count.rs index 08e344e04a..00a3264adb 100644 --- a/src/daft-functions/src/list/count.rs +++ b/src/daft-functions/src/list/count.rs @@ -57,6 +57,7 @@ impl ScalarUDF for ListCount { } } +#[must_use] pub fn list_count(expr: ExprRef, mode: CountMode) -> ExprRef { ScalarFunction::new(ListCount { mode }, vec![expr]).into() } diff --git a/src/daft-functions/src/list/explode.rs b/src/daft-functions/src/list/explode.rs index a2232b33f9..6cf187e291 100644 --- a/src/daft-functions/src/list/explode.rs +++ b/src/daft-functions/src/list/explode.rs @@ -46,6 +46,7 @@ impl ScalarUDF for Explode { } } +#[must_use] pub fn explode(expr: ExprRef) -> ExprRef { ScalarFunction::new(Explode {}, vec![expr]).into() } diff --git a/src/daft-functions/src/list/get.rs b/src/daft-functions/src/list/get.rs index 15f088ce0c..45dc2b8cd0 100644 --- a/src/daft-functions/src/list/get.rs +++ b/src/daft-functions/src/list/get.rs @@ -59,6 +59,7 @@ impl ScalarUDF for ListGet { } } +#[must_use] pub fn list_get(expr: ExprRef, idx: ExprRef, default_value: ExprRef) -> ExprRef { ScalarFunction::new(ListGet {}, vec![expr, idx, default_value]).into() } diff --git a/src/daft-functions/src/list/join.rs b/src/daft-functions/src/list/join.rs index 83d2f87efb..fdb2ea3bcd 100644 --- a/src/daft-functions/src/list/join.rs +++ b/src/daft-functions/src/list/join.rs @@ -70,6 +70,7 @@ impl ScalarUDF for ListJoin { } } +#[must_use] pub fn list_join(expr: ExprRef, delim: ExprRef) -> ExprRef { ScalarFunction::new(ListJoin {}, vec![expr, delim]).into() } diff --git a/src/daft-functions/src/list/max.rs b/src/daft-functions/src/list/max.rs index 22621eb7f9..c6d6ded13e 100644 --- a/src/daft-functions/src/list/max.rs +++ b/src/daft-functions/src/list/max.rs @@ -54,6 +54,7 @@ impl ScalarUDF for ListMax { } } +#[must_use] pub fn list_max(expr: ExprRef) -> ExprRef { ScalarFunction::new(ListMax {}, vec![expr]).into() } diff --git a/src/daft-functions/src/list/mean.rs b/src/daft-functions/src/list/mean.rs index 
16a817a9c3..b01d3c1fa1 100644 --- a/src/daft-functions/src/list/mean.rs +++ b/src/daft-functions/src/list/mean.rs @@ -1,6 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::{ - datatypes::try_mean_supertype, + datatypes::try_mean_stddev_aggregation_supertype, prelude::{Field, Schema}, series::Series, }; @@ -29,7 +29,7 @@ impl ScalarUDF for ListMean { let inner_field = input.to_field(schema)?.to_exploded_field()?; Ok(Field::new( inner_field.name.as_str(), - try_mean_supertype(&inner_field.dtype)?, + try_mean_stddev_aggregation_supertype(&inner_field.dtype)?, )) } _ => Err(DaftError::SchemaMismatch(format!( @@ -50,6 +50,7 @@ impl ScalarUDF for ListMean { } } +#[must_use] pub fn list_mean(expr: ExprRef) -> ExprRef { ScalarFunction::new(ListMean {}, vec![expr]).into() } diff --git a/src/daft-functions/src/list/min.rs b/src/daft-functions/src/list/min.rs index 8386b38410..55af30e154 100644 --- a/src/daft-functions/src/list/min.rs +++ b/src/daft-functions/src/list/min.rs @@ -51,6 +51,7 @@ impl ScalarUDF for ListMin { } } +#[must_use] pub fn list_min(expr: ExprRef) -> ExprRef { ScalarFunction::new(ListMin {}, vec![expr]).into() } diff --git a/src/daft-functions/src/list/mod.rs b/src/daft-functions/src/list/mod.rs index 2ba3f197be..c0ad745b19 100644 --- a/src/daft-functions/src/list/mod.rs +++ b/src/daft-functions/src/list/mod.rs @@ -9,6 +9,7 @@ mod min; mod slice; mod sort; mod sum; +mod value_counts; pub use chunk::{list_chunk as chunk, ListChunk}; pub use count::{list_count as count, ListCount}; @@ -31,6 +32,10 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(count::py_list_count, parent)?)?; parent.add_function(wrap_pyfunction_bound!(get::py_list_get, parent)?)?; parent.add_function(wrap_pyfunction_bound!(join::py_list_join, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + value_counts::py_list_value_counts, + parent + )?)?; parent.add_function(wrap_pyfunction_bound!(max::py_list_max, parent)?)?; parent.add_function(wrap_pyfunction_bound!(min::py_list_min, parent)?)?; diff --git a/src/daft-functions/src/list/slice.rs b/src/daft-functions/src/list/slice.rs index f62e47474d..ffde7f0b7a 100644 --- a/src/daft-functions/src/list/slice.rs +++ b/src/daft-functions/src/list/slice.rs @@ -62,6 +62,7 @@ impl ScalarUDF for ListSlice { } } +#[must_use] pub fn list_slice(expr: ExprRef, start: ExprRef, end: ExprRef) -> ExprRef { ScalarFunction::new(ListSlice {}, vec![expr, start, end]).into() } diff --git a/src/daft-functions/src/list/sort.rs b/src/daft-functions/src/list/sort.rs index 3d75e3fa48..2d1ef45afb 100644 --- a/src/daft-functions/src/list/sort.rs +++ b/src/daft-functions/src/list/sort.rs @@ -23,10 +23,10 @@ impl ScalarUDF for ListSort { match inputs { [data, desc] => match (data.to_field(schema), desc.to_field(schema)) { (Ok(field), Ok(desc_field)) => match (&field.dtype, &desc_field.dtype) { - (l @ DataType::List(_), DataType::Boolean) - | (l @ DataType::FixedSizeList(_, _), DataType::Boolean) => { - Ok(Field::new(field.name, l.clone())) - } + ( + l @ (DataType::List(_) | DataType::FixedSizeList(_, _)), + DataType::Boolean, + ) => Ok(Field::new(field.name, l.clone())), (a, b) => Err(DaftError::TypeError(format!( "Expects inputs to list_sort to be list and bool, but received {a} and {b}", ))), @@ -51,6 +51,7 @@ impl ScalarUDF for ListSort { } } +#[must_use] pub fn list_sort(input: ExprRef, desc: Option) -> ExprRef { let desc = desc.unwrap_or_else(|| lit(false)); ScalarFunction::new(ListSort {}, vec![input, 
desc]).into() diff --git a/src/daft-functions/src/list/sum.rs b/src/daft-functions/src/list/sum.rs index 82883faf26..79c04d9f6f 100644 --- a/src/daft-functions/src/list/sum.rs +++ b/src/daft-functions/src/list/sum.rs @@ -54,6 +54,7 @@ impl ScalarUDF for ListSum { } } +#[must_use] pub fn list_sum(expr: ExprRef) -> ExprRef { ScalarFunction::new(ListSum {}, vec![expr]).into() } diff --git a/src/daft-functions/src/list/value_counts.rs b/src/daft-functions/src/list/value_counts.rs new file mode 100644 index 0000000000..d558db8ac4 --- /dev/null +++ b/src/daft-functions/src/list/value_counts.rs @@ -0,0 +1,72 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::prelude::{DataType, Field, Schema, Series}; +#[cfg(feature = "python")] +use daft_dsl::python::PyExpr; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +#[cfg(feature = "python")] +use pyo3::{pyfunction, PyResult}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +struct ListValueCountsFunction; + +#[typetag::serde] +impl ScalarUDF for ListValueCountsFunction { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &'static str { + "list_value_counts" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + let [data] = inputs else { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + + let data_field = data.to_field(schema)?; + + let DataType::List(inner_type) = &data_field.dtype else { + return Err(DaftError::TypeError(format!( + "Expected list, got {}", + data_field.dtype + ))); + }; + + let map_type = DataType::Map { + key: inner_type.clone(), + value: Box::new(DataType::UInt64), + }; + + Ok(Field::new(data_field.name, map_type)) + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + let [data] = inputs else { + return Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + }; + + data.list_value_counts() + } +} + +pub fn list_value_counts(expr: ExprRef) -> ExprRef { + ScalarFunction::new(ListValueCountsFunction, vec![expr]).into() +} + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "list_value_counts")] +pub fn py_list_value_counts(expr: PyExpr) -> PyResult { + Ok(list_value_counts(expr.into()).into()) +} diff --git a/src/daft-functions/src/minhash.rs b/src/daft-functions/src/minhash.rs index 48d13e0a65..1aaa82b3e5 100644 --- a/src/daft-functions/src/minhash.rs +++ b/src/daft-functions/src/minhash.rs @@ -7,10 +7,10 @@ use daft_dsl::{ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(super) struct MinHashFunction { - num_hashes: usize, - ngram_size: usize, - seed: u32, +pub struct MinHashFunction { + pub num_hashes: usize, + pub ngram_size: usize, + pub seed: u32, } #[typetag::serde] @@ -55,6 +55,7 @@ impl ScalarUDF for MinHashFunction { } } +#[must_use] pub fn minhash(input: ExprRef, num_hashes: usize, ngram_size: usize, seed: u32) -> ExprRef { ScalarFunction::new( MinHashFunction { diff --git a/src/daft-functions/src/numeric/abs.rs b/src/daft-functions/src/numeric/abs.rs index f054950e0f..133dd82478 100644 --- a/src/daft-functions/src/numeric/abs.rs +++ b/src/daft-functions/src/numeric/abs.rs @@ -31,6 +31,7 @@ impl ScalarUDF for Abs { } } +#[must_use] pub fn abs(input: ExprRef) -> ExprRef { ScalarFunction::new(Abs {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/cbrt.rs 
b/src/daft-functions/src/numeric/cbrt.rs index c9b4e9286f..3b49db2984 100644 --- a/src/daft-functions/src/numeric/cbrt.rs +++ b/src/daft-functions/src/numeric/cbrt.rs @@ -29,6 +29,7 @@ impl ScalarUDF for Cbrt { } } +#[must_use] pub fn cbrt(input: ExprRef) -> ExprRef { ScalarFunction::new(Cbrt {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/ceil.rs b/src/daft-functions/src/numeric/ceil.rs index 26c37bec6b..8f733ca332 100644 --- a/src/daft-functions/src/numeric/ceil.rs +++ b/src/daft-functions/src/numeric/ceil.rs @@ -30,6 +30,7 @@ impl ScalarUDF for Ceil { } } +#[must_use] pub fn ceil(input: ExprRef) -> ExprRef { ScalarFunction::new(Ceil {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/exp.rs b/src/daft-functions/src/numeric/exp.rs index abde081b46..d56c608c79 100644 --- a/src/daft-functions/src/numeric/exp.rs +++ b/src/daft-functions/src/numeric/exp.rs @@ -49,6 +49,7 @@ impl ScalarUDF for Exp { } } +#[must_use] pub fn exp(input: ExprRef) -> ExprRef { ScalarFunction::new(Exp {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/floor.rs b/src/daft-functions/src/numeric/floor.rs index 36ec365e0f..9debcc4823 100644 --- a/src/daft-functions/src/numeric/floor.rs +++ b/src/daft-functions/src/numeric/floor.rs @@ -31,6 +31,7 @@ impl ScalarUDF for Floor { } } +#[must_use] pub fn floor(input: ExprRef) -> ExprRef { ScalarFunction::new(Floor {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/log.rs b/src/daft-functions/src/numeric/log.rs index 7aecb2de56..4e90f20672 100644 --- a/src/daft-functions/src/numeric/log.rs +++ b/src/daft-functions/src/numeric/log.rs @@ -52,6 +52,7 @@ macro_rules! log { } } + #[must_use] pub fn $name(input: ExprRef) -> ExprRef { ScalarFunction::new($variant, vec![input]).into() } @@ -101,6 +102,7 @@ impl ScalarUDF for Log { } } +#[must_use] pub fn log(input: ExprRef, base: f64) -> ExprRef { ScalarFunction::new(Log(FloatWrapper(base)), vec![input]).into() } diff --git a/src/daft-functions/src/numeric/round.rs b/src/daft-functions/src/numeric/round.rs index 395b0ee696..bf7a51ed5d 100644 --- a/src/daft-functions/src/numeric/round.rs +++ b/src/daft-functions/src/numeric/round.rs @@ -33,6 +33,7 @@ impl ScalarUDF for Round { } } +#[must_use] pub fn round(input: ExprRef, decimal: i32) -> ExprRef { ScalarFunction::new(Round { decimal }, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/sign.rs b/src/daft-functions/src/numeric/sign.rs index a58b7f294d..4355a1bfbe 100644 --- a/src/daft-functions/src/numeric/sign.rs +++ b/src/daft-functions/src/numeric/sign.rs @@ -31,6 +31,7 @@ impl ScalarUDF for Sign { } } +#[must_use] pub fn sign(input: ExprRef) -> ExprRef { ScalarFunction::new(Sign {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/sqrt.rs b/src/daft-functions/src/numeric/sqrt.rs index 11766e4f17..2e5ba26e7a 100644 --- a/src/daft-functions/src/numeric/sqrt.rs +++ b/src/daft-functions/src/numeric/sqrt.rs @@ -33,6 +33,7 @@ impl ScalarUDF for Sqrt { } } +#[must_use] pub fn sqrt(input: ExprRef) -> ExprRef { ScalarFunction::new(Sqrt {}, vec![input]).into() } diff --git a/src/daft-functions/src/numeric/trigonometry.rs b/src/daft-functions/src/numeric/trigonometry.rs index 9a47875596..43997c72b3 100644 --- a/src/daft-functions/src/numeric/trigonometry.rs +++ b/src/daft-functions/src/numeric/trigonometry.rs @@ -56,6 +56,7 @@ macro_rules! 
trigonometry { } } + #[must_use] pub fn $name(input: ExprRef) -> ExprRef { ScalarFunction::new($variant, vec![input]).into() } @@ -102,8 +103,7 @@ impl ScalarUDF for Atan2 { (dt1, dt2) if dt1.is_numeric() && dt2.is_numeric() => DataType::Float64, (dt1, dt2) => { return Err(DaftError::TypeError(format!( - "Expected inputs to atan2 to be numeric, got {} and {}", - dt1, dt2 + "Expected inputs to atan2 to be numeric, got {dt1} and {dt2}" ))) } }; @@ -121,6 +121,7 @@ impl ScalarUDF for Atan2 { } } +#[must_use] pub fn atan2(x: ExprRef, y: ExprRef) -> ExprRef { ScalarFunction::new(Atan2 {}, vec![x, y]).into() } diff --git a/src/daft-functions/src/temporal/mod.rs b/src/daft-functions/src/temporal/mod.rs index 314546fe77..a9db4dd96d 100644 --- a/src/daft-functions/src/temporal/mod.rs +++ b/src/daft-functions/src/temporal/mod.rs @@ -77,7 +77,7 @@ macro_rules! impl_temporal { } } - pub fn $dt(input: ExprRef) -> ExprRef { + #[must_use] pub fn $dt(input: ExprRef) -> ExprRef { ScalarFunction::new($name {}, vec![input]).into() } @@ -150,6 +150,7 @@ impl ScalarUDF for Time { } } +#[must_use] pub fn dt_time(input: ExprRef) -> ExprRef { ScalarFunction::new(Time {}, vec![input]).into() } @@ -182,7 +183,7 @@ mod test { (Arc::new(Year), "year"), ( Arc::new(Truncate { - interval: "".into(), + interval: String::new(), }), "truncate", ), diff --git a/src/daft-functions/src/to_struct.rs b/src/daft-functions/src/to_struct.rs index 73f390eb26..112ce0bc3a 100644 --- a/src/daft-functions/src/to_struct.rs +++ b/src/daft-functions/src/to_struct.rs @@ -49,6 +49,7 @@ impl ScalarUDF for ToStructFunction { } } +#[must_use] pub fn to_struct(inputs: Vec) -> ExprRef { ScalarFunction::new(ToStructFunction {}, inputs).into() } @@ -60,7 +61,7 @@ pub mod python { #[pyfunction] pub fn to_struct(inputs: Vec) -> PyResult { - let inputs = inputs.into_iter().map(|x| x.into()).collect(); + let inputs = inputs.into_iter().map(std::convert::Into::into).collect(); let expr = super::to_struct(inputs); Ok(expr.into()) } diff --git a/src/daft-functions/src/tokenize/bpe.rs b/src/daft-functions/src/tokenize/bpe.rs index c35e41771e..98c826f498 100644 --- a/src/daft-functions/src/tokenize/bpe.rs +++ b/src/daft-functions/src/tokenize/bpe.rs @@ -60,7 +60,10 @@ pub enum Error { impl From for DaftError { fn from(err: Error) -> Self { - use Error::*; + use Error::{ + BPECreation, BadToken, Base64Decode, Decode, EmptyTokenFile, InvalidTokenLine, + InvalidUtf8Sequence, MissingPattern, RankNumberParse, UnsupportedSpecialTokens, + }; match err { Base64Decode { .. } => Self::ValueError(err.to_string()), RankNumberParse { .. } => Self::ValueError(err.to_string()), diff --git a/src/daft-functions/src/tokenize/decode.rs b/src/daft-functions/src/tokenize/decode.rs index 30a713f993..e7a5724fdf 100644 --- a/src/daft-functions/src/tokenize/decode.rs +++ b/src/daft-functions/src/tokenize/decode.rs @@ -59,18 +59,17 @@ fn tokenize_decode_series( )? 
.into_series()), dt => Err(DaftError::TypeError(format!( - "Tokenize decode not implemented for type {}", - dt + "Tokenize decode not implemented for type {dt}" ))), } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(super) struct TokenizeDecodeFunction { - pub(super) tokens_path: String, - pub(super) io_config: Option>, - pub(super) pattern: Option, - pub(super) special_tokens: Option, +pub struct TokenizeDecodeFunction { + pub tokens_path: String, + pub io_config: Option>, + pub pattern: Option, + pub special_tokens: Option, } #[typetag::serde] diff --git a/src/daft-functions/src/tokenize/encode.rs b/src/daft-functions/src/tokenize/encode.rs index e36f9be4d2..1cb1829c9f 100644 --- a/src/daft-functions/src/tokenize/encode.rs +++ b/src/daft-functions/src/tokenize/encode.rs @@ -26,7 +26,7 @@ fn tokenize_encode_array( let mut offsets: Vec = Vec::with_capacity(arr.len() + 1); offsets.push(0); let self_arrow = arr.as_arrow(); - for s_opt in self_arrow.iter() { + for s_opt in self_arrow { if let Some(s) = s_opt { let tokens = bpe.encode(s, use_special_tokens); let tokens_iter = tokens.iter().map(|t| Some(*t)); @@ -70,12 +70,12 @@ fn tokenize_encode_series( } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub(super) struct TokenizeEncodeFunction { - pub(super) tokens_path: String, - pub(super) io_config: Option>, - pub(super) pattern: Option, - pub(super) special_tokens: Option, - pub(super) use_special_tokens: bool, +pub struct TokenizeEncodeFunction { + pub tokens_path: String, + pub io_config: Option>, + pub pattern: Option, + pub special_tokens: Option, + pub use_special_tokens: bool, } #[typetag::serde] diff --git a/src/daft-functions/src/tokenize/mod.rs b/src/daft-functions/src/tokenize/mod.rs index 8eb1aee7e1..4462e8096b 100644 --- a/src/daft-functions/src/tokenize/mod.rs +++ b/src/daft-functions/src/tokenize/mod.rs @@ -1,7 +1,7 @@ use daft_dsl::{functions::ScalarFunction, ExprRef}; use daft_io::IOConfig; -use decode::TokenizeDecodeFunction; -use encode::TokenizeEncodeFunction; +pub use decode::TokenizeDecodeFunction; +pub use encode::TokenizeEncodeFunction; mod bpe; mod decode; @@ -19,7 +19,7 @@ pub fn tokenize_encode( ScalarFunction::new( TokenizeEncodeFunction { tokens_path: tokens_path.to_string(), - io_config: io_config.map(|x| x.into()), + io_config: io_config.map(std::convert::Into::into), pattern: pattern.map(str::to_string), special_tokens: special_tokens.map(str::to_string), use_special_tokens, @@ -39,7 +39,7 @@ pub fn tokenize_decode( ScalarFunction::new( TokenizeDecodeFunction { tokens_path: tokens_path.to_string(), - io_config: io_config.map(|x| x.into()), + io_config: io_config.map(std::convert::Into::into), pattern: pattern.map(str::to_string), special_tokens: special_tokens.map(str::to_string), }, diff --git a/src/daft-functions/src/tokenize/special_tokens.rs b/src/daft-functions/src/tokenize/special_tokens.rs index c00b33e1d3..1d5a50eac3 100644 --- a/src/daft-functions/src/tokenize/special_tokens.rs +++ b/src/daft-functions/src/tokenize/special_tokens.rs @@ -15,7 +15,7 @@ fn get_llama3_tokens() -> Vec { .map(str::to_string) .collect(); for i in 5..256 { - res.push(format!("<|reserved_special_token_{}|>", i)); + res.push(format!("<|reserved_special_token_{i}|>")); } res } diff --git a/src/daft-functions/src/uri/download.rs b/src/daft-functions/src/uri/download.rs index 9f107e95c1..15ebc2f9fc 100644 --- a/src/daft-functions/src/uri/download.rs +++ b/src/daft-functions/src/uri/download.rs @@ -52,8 +52,7 @@ impl ScalarUDF for 
DownloadFunction { Ok(result.into_series()) } _ => Err(DaftError::TypeError(format!( - "Download can only download uris from Utf8Array, got {}", - input + "Download can only download uris from Utf8Array, got {input}" ))), }, _ => Err(DaftError::ValueError(format!( @@ -71,8 +70,7 @@ impl ScalarUDF for DownloadFunction { match &field.dtype { DataType::Utf8 => Ok(Field::new(field.name, DataType::Binary)), _ => Err(DaftError::TypeError(format!( - "Download can only download uris from Utf8Array, got {}", - field + "Download can only download uris from Utf8Array, got {field}" ))), } } @@ -108,11 +106,16 @@ fn url_download( let io_client = get_io_client(multi_thread, config)?; let owned_array = array.clone(); + + #[expect( + clippy::needless_collect, + reason = "This actually might be needed, but need to double check TODO:(andrewgazelka)" + )] let fetches = async move { let urls = owned_array .as_arrow() .into_iter() - .map(|s| s.map(|s| s.to_string())) + .map(|s| s.map(std::string::ToString::to_string)) .collect::>(); let stream = futures::stream::iter(urls.into_iter().enumerate().map(move |(i, url)| { @@ -146,20 +149,17 @@ fn url_download( let cap_needed: usize = results .iter() - .filter_map(|f| f.1.as_ref().map(|f| f.len())) + .filter_map(|f| f.1.as_ref().map(bytes::Bytes::len)) .sum(); let mut data = Vec::with_capacity(cap_needed); - for (_, b) in results.into_iter() { - match b { - Some(b) => { - data.extend(b.as_ref()); - offsets.push(b.len() as i64 + offsets.last().unwrap()); - valid.push(true); - } - None => { - offsets.push(*offsets.last().unwrap()); - valid.push(false); - } + for (_, b) in results { + if let Some(b) = b { + data.extend(b.as_ref()); + offsets.push(b.len() as i64 + offsets.last().unwrap()); + valid.push(true); + } else { + offsets.push(*offsets.last().unwrap()); + valid.push(false); } } Ok(BinaryArray::try_from((name, data, offsets))? diff --git a/src/daft-functions/src/uri/mod.rs b/src/daft-functions/src/uri/mod.rs index d06e0bb112..df67776455 100644 --- a/src/daft-functions/src/uri/mod.rs +++ b/src/daft-functions/src/uri/mod.rs @@ -6,6 +6,7 @@ use daft_dsl::{functions::ScalarFunction, ExprRef}; use download::DownloadFunction; use upload::UploadFunction; +#[must_use] pub fn download( input: ExprRef, max_connections: usize, @@ -25,6 +26,7 @@ pub fn download( .into() } +#[must_use] pub fn upload( input: ExprRef, location: &str, diff --git a/src/daft-functions/src/uri/upload.rs b/src/daft-functions/src/uri/upload.rs index d4c606955f..4ab677614c 100644 --- a/src/daft-functions/src/uri/upload.rs +++ b/src/daft-functions/src/uri/upload.rs @@ -55,7 +55,7 @@ impl ScalarUDF for UploadFunction { let data_field = data.to_field(schema)?; match data_field.dtype { DataType::Binary | DataType::FixedSizeBinary(..) | DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!("Expects input to url_upload to be Binary, FixedSizeBinary or String, but received {}", data_field))), + _ => Err(DaftError::TypeError(format!("Expects input to url_upload to be Binary, FixedSizeBinary or String, but received {data_field}"))), } } _ => Err(DaftError::SchemaMismatch(format!( diff --git a/src/daft-image/src/image_buffer.rs b/src/daft-image/src/image_buffer.rs index f1595aaf1f..cab1432373 100644 --- a/src/daft-image/src/image_buffer.rs +++ b/src/daft-image/src/image_buffer.rs @@ -46,7 +46,7 @@ macro_rules! 
with_method_on_image_buffer { impl<'a> DaftImageBuffer<'a> { pub fn from_raw(mode: &ImageMode, width: u32, height: u32, data: Cow<'a, [u8]>) -> Self { - use DaftImageBuffer::*; + use DaftImageBuffer::{L, LA, RGB, RGBA}; match mode { ImageMode::L => L(ImageBuffer::from_raw(width, height, data).unwrap()), ImageMode::LA => LA(ImageBuffer::from_raw(width, height, data).unwrap()), @@ -64,7 +64,7 @@ impl<'a> DaftImageBuffer<'a> { } pub fn as_u8_slice(&self) -> &[u8] { - use DaftImageBuffer::*; + use DaftImageBuffer::{L, LA, RGB, RGBA}; match self { L(img) => img.as_raw(), LA(img) => img.as_raw(), @@ -74,7 +74,7 @@ impl<'a> DaftImageBuffer<'a> { } } pub fn mode(&self) -> ImageMode { - use DaftImageBuffer::*; + use DaftImageBuffer::{L, L16, LA, LA16, RGB, RGB16, RGB32F, RGBA, RGBA16, RGBA32F}; match self { L(..) => ImageMode::L, @@ -91,7 +91,7 @@ impl<'a> DaftImageBuffer<'a> { } pub fn color(&self) -> ColorType { let mode = DaftImageBuffer::mode(self); - use ImageMode::*; + use ImageMode::{L, L16, LA, LA16, RGB, RGB16, RGB32F, RGBA, RGBA16, RGBA32F}; match mode { L => ColorType::L8, LA => ColorType::La8, @@ -108,8 +108,8 @@ impl<'a> DaftImageBuffer<'a> { pub fn decode(bytes: &[u8]) -> DaftResult { image::load_from_memory(bytes) - .map(|v| v.into()) - .map_err(|e| DaftError::ValueError(format!("Decoding image from bytes failed: {}", e))) + .map(std::convert::Into::into) + .map_err(|e| DaftError::ValueError(format!("Decoding image from bytes failed: {e}"))) } pub fn encode(&self, image_format: ImageFormat, writer: &mut W) -> DaftResult<()> @@ -126,8 +126,7 @@ impl<'a> DaftImageBuffer<'a> { ) .map_err(|e| { DaftError::ValueError(format!( - "Encoding image into file format {} failed: {}", - image_format, e + "Encoding image into file format {image_format} failed: {e}" )) }) } @@ -146,7 +145,7 @@ impl<'a> DaftImageBuffer<'a> { } pub fn resize(&self, w: u32, h: u32) -> Self { - use DaftImageBuffer::*; + use DaftImageBuffer::{L, LA, RGB, RGBA}; match self { L(imgbuf) => { let result = diff --git a/src/daft-image/src/ops.rs b/src/daft-image/src/ops.rs index 6a3e2a6a75..39dc5bf886 100644 --- a/src/daft-image/src/ops.rs +++ b/src/daft-image/src/ops.rs @@ -12,7 +12,7 @@ use daft_core::{ }; use num_traits::FromPrimitive; -use crate::{iters::*, CountingWriter, DaftImageBuffer}; +use crate::{iters::ImageBufferIter, CountingWriter, DaftImageBuffer}; #[allow(clippy::len_without_is_empty)] pub trait AsImageObj { @@ -45,7 +45,7 @@ pub(crate) fn image_array_from_img_buffers( inputs: &[Option>], image_mode: &Option, ) -> DaftResult { - use DaftImageBuffer::*; + use DaftImageBuffer::{L, LA, RGB, RGBA}; let is_all_u8 = inputs .iter() .filter_map(|b| b.as_ref()) @@ -102,7 +102,7 @@ pub(crate) fn fixed_image_array_from_img_buffers( height: u32, width: u32, ) -> DaftResult { - use DaftImageBuffer::*; + use DaftImageBuffer::{L, LA, RGB, RGBA}; let is_all_u8 = inputs .iter() .filter_map(|b| b.as_ref()) @@ -112,15 +112,15 @@ pub(crate) fn fixed_image_array_from_img_buffers( let num_channels = image_mode.num_channels(); let mut data_ref = Vec::with_capacity(inputs.len()); let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(inputs.len()); - let list_size = (height * width * num_channels as u32) as usize; + let list_size = (height * width * u32::from(num_channels)) as usize; let null_list = vec![0u8; list_size]; - for ib in inputs.iter() { + for ib in inputs { validity.push(ib.is_some()); let buffer = match ib { Some(ib) => ib.as_u8_slice(), None => null_list.as_slice(), }; - data_ref.push(buffer) + 
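// The loop above packs fixed-shape images row by row: each slot spans
// height * width * channels bytes, a missing image contributes a zeroed slot,
// and validity lives in a separate mask (arrow2's MutableBitmap in the real
// code). A std-only sketch of that packing; `pack_fixed_shape` is an
// illustrative name, not the crate's API:

fn pack_fixed_shape(
    images: &[Option<Vec<u8>>],
    height: u32,
    width: u32,
    channels: u32,
) -> (Vec<u8>, Vec<bool>) {
    let slot = (height * width * channels) as usize;
    let mut data = Vec::with_capacity(images.len() * slot);
    let mut validity = Vec::with_capacity(images.len());
    for img in images {
        match img {
            Some(bytes) => {
                // Every valid row must match the fixed shape exactly.
                assert_eq!(bytes.len(), slot, "image does not match fixed shape");
                data.extend_from_slice(bytes);
                validity.push(true);
            }
            None => {
                // Null rows still occupy a full zeroed slot so offsets stay fixed.
                data.extend(std::iter::repeat(0u8).take(slot));
                validity.push(false);
            }
        }
    }
    (data, validity)
}

fn main() {
    let (data, validity) = pack_fixed_shape(&[Some(vec![1, 2, 3]), None], 1, 1, 3);
    assert_eq!(data, vec![1, 2, 3, 0, 0, 0]);
    assert_eq!(validity, vec![true, false]);
}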
data_ref.push(buffer); } let data = data_ref.concat(); let validity: Option = match validity.unset_bits() { @@ -324,7 +324,7 @@ impl AsImageObj for FixedShapeImageArray { DataType::FixedShapeImage(mode, height, width) => { let arrow_array = self.physical.flat_child.downcast::().unwrap().as_arrow(); let num_channels = mode.num_channels(); - let size = height * width * num_channels as u32; + let size = height * width * u32::from(num_channels); let start = idx * size as usize; let end = (idx + 1) * size as usize; let slice_data = Cow::Borrowed(&arrow_array.values().as_slice()[start..end] as &'a [u8]); @@ -334,7 +334,7 @@ impl AsImageObj for FixedShapeImageArray { assert_eq!(result.width(), *width); Some(result) } - dt => panic!("FixedShapeImageArray should always have DataType::FixedShapeImage() as it's dtype, but got {}", dt), + dt => panic!("FixedShapeImageArray should always have DataType::FixedShapeImage() as it's dtype, but got {dt}"), } } } @@ -343,83 +343,75 @@ fn encode_images( images: &Arr, image_format: ImageFormat, ) -> DaftResult { - let arrow_array = match image_format { - ImageFormat::TIFF => { - // NOTE: A single writer/buffer can't be used for TIFF files because the encoder will overwrite the - // IFD offset for the first image instead of writing it for all subsequent images, producing corrupted - // TIFF files. We work around this by writing out a new buffer for each image. - // TODO(Clark): Fix this in the tiff crate. - let values = ImageBufferIter::new(images) - .map(|img| { - img.map(|img| { - let buf = Vec::new(); - let mut writer: CountingWriter> = - std::io::BufWriter::new(std::io::Cursor::new(buf)).into(); - img.encode(image_format, &mut writer)?; - // NOTE: BufWriter::into_inner() will flush the buffer. - Ok(writer - .into_inner() - .into_inner() - .map_err(|e| { - DaftError::ValueError(format!( - "Encoding image into file format {} failed: {}", - image_format, e - )) - })? - .into_inner()) - }) - .transpose() - }) - .collect::>>()?; - arrow2::array::BinaryArray::::from_iter(values) - } - _ => { - let mut offsets = Vec::with_capacity(images.len() + 1); - offsets.push(0i64); - let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(images.len()); - let buf = Vec::new(); - let mut writer: CountingWriter> = - std::io::BufWriter::new(std::io::Cursor::new(buf)).into(); - ImageBufferIter::new(images) - .map(|img| { - match img { - Some(img) => { - img.encode(image_format, &mut writer)?; - offsets.push(writer.count() as i64); - validity.push(true); - } - None => { - offsets.push(*offsets.last().unwrap()); - validity.push(false); - } - } - Ok(()) + let arrow_array = if image_format == ImageFormat::TIFF { + // NOTE: A single writer/buffer can't be used for TIFF files because the encoder will overwrite the + // IFD offset for the first image instead of writing it for all subsequent images, producing corrupted + // TIFF files. We work around this by writing out a new buffer for each image. + // TODO(Clark): Fix this in the tiff crate. + let values = ImageBufferIter::new(images) + .map(|img| { + img.map(|img| { + let buf = Vec::new(); + let mut writer: CountingWriter> = + std::io::BufWriter::new(std::io::Cursor::new(buf)).into(); + img.encode(image_format, &mut writer)?; + // NOTE: BufWriter::into_inner() will flush the buffer. + Ok(writer + .into_inner() + .into_inner() + .map_err(|e| { + DaftError::ValueError(format!( + "Encoding image into file format {image_format} failed: {e}" + )) + })? 
+ .into_inner()) }) - .collect::>>()?; - // NOTE: BufWriter::into_inner() will flush the buffer. - let values = writer - .into_inner() - .into_inner() - .map_err(|e| { - DaftError::ValueError(format!( - "Encoding image into file format {} failed: {}", - image_format, e - )) - })? - .into_inner(); - let encoded_data: arrow2::buffer::Buffer = values.into(); - let offsets_buffer = arrow2::offset::OffsetsBuffer::try_from(offsets)?; - let validity: Option = match validity.unset_bits() { - 0 => None, - _ => Some(validity.into()), - }; - arrow2::array::BinaryArray::::new( - arrow2::datatypes::DataType::LargeBinary, - offsets_buffer, - encoded_data, - validity, - ) - } + .transpose() + }) + .collect::>>()?; + arrow2::array::BinaryArray::::from_iter(values) + } else { + let mut offsets = Vec::with_capacity(images.len() + 1); + offsets.push(0i64); + let mut validity = arrow2::bitmap::MutableBitmap::with_capacity(images.len()); + let buf = Vec::new(); + let mut writer: CountingWriter> = + std::io::BufWriter::new(std::io::Cursor::new(buf)).into(); + ImageBufferIter::new(images) + .map(|img| { + if let Some(img) = img { + img.encode(image_format, &mut writer)?; + offsets.push(writer.count() as i64); + validity.push(true); + } else { + offsets.push(*offsets.last().unwrap()); + validity.push(false); + } + Ok(()) + }) + .collect::>>()?; + // NOTE: BufWriter::into_inner() will flush the buffer. + let values = writer + .into_inner() + .into_inner() + .map_err(|e| { + DaftError::ValueError(format!( + "Encoding image into file format {image_format} failed: {e}" + )) + })? + .into_inner(); + let encoded_data: arrow2::buffer::Buffer = values.into(); + let offsets_buffer = arrow2::offset::OffsetsBuffer::try_from(offsets)?; + let validity: Option = match validity.unset_bits() { + 0 => None, + _ => Some(validity.into()), + }; + arrow2::array::BinaryArray::::new( + arrow2::datatypes::DataType::LargeBinary, + offsets_buffer, + encoded_data, + validity, + ) }; BinaryArray::new( Field::new(images.name(), arrow_array.data_type().into()).into(), @@ -449,6 +441,7 @@ where .collect::>() } +#[must_use] pub fn image_html_value(arr: &ImageArray, idx: usize) -> String { let maybe_image = arr.as_image_obj(idx); let str_val = arr.str_value(idx).unwrap(); @@ -470,6 +463,7 @@ pub fn image_html_value(arr: &ImageArray, idx: usize) -> String { } } +#[must_use] pub fn fixed_image_html_value(arr: &FixedShapeImageArray, idx: usize) -> String { let maybe_image = arr.as_image_obj(idx); let str_val = arr.str_value(idx).unwrap(); diff --git a/src/daft-image/src/series.rs b/src/daft-image/src/series.rs index 636353768e..ae789e3d93 100644 --- a/src/daft-image/src/series.rs +++ b/src/daft-image/src/series.rs @@ -25,13 +25,12 @@ fn image_decode_impl( Err(err) => { if raise_error_on_failure { return Err(err); - } else { - log::warn!( + } + log::warn!( "Error occurred during image decoding at index: {index} {} (falling back to Null)", err ); - None - } + None } }; if let Some(mode) = mode { @@ -42,8 +41,7 @@ fn image_decode_impl( (Some(t1), Some(t2)) => { if t1 != t2 { return Err(DaftError::ValueError(format!( - "All images in a column must have the same dtype, but got: {:?} and {:?}", - t1, t2 + "All images in a column must have the same dtype, but got: {t1:?} and {t2:?}" ))); } } @@ -80,8 +78,7 @@ pub fn decode( DataType::Binary => image_decode_impl(s.binary()?, raise_error_on_failure, mode) .map(|arr| arr.into_series()), dtype => Err(DaftError::ValueError(format!( - "Decoding in-memory data into images is only supported for binary arrays, but 
got {}", - dtype + "Decoding in-memory data into images is only supported for binary arrays, but got {dtype}" ))), } } @@ -109,8 +106,7 @@ pub fn encode(s: &Series, image_format: ImageFormat) -> DaftResult { .encode(image_format)? .into_series()), dtype => Err(DaftError::ValueError(format!( - "Encoding images into bytes is only supported for image arrays, but got {}", - dtype + "Encoding images into bytes is only supported for image arrays, but got {dtype}" ))), } } @@ -167,13 +163,14 @@ pub fn crop(s: &Series, bbox: &Series) -> DaftResult { .downcast::()? .crop(bbox) .map(|arr| arr.into_series()), + DataType::FixedShapeImage(..) => s .fixed_size_image()? .crop(bbox) .map(|arr| arr.into_series()), + dt => Err(DaftError::ValueError(format!( - "Expected input to crop to be an Image type, but received: {}", - dt + "Expected input to crop to be an Image type, but received: {dt}" ))), } } @@ -196,8 +193,7 @@ pub fn to_mode(s: &Series, mode: ImageMode) -> DaftResult { .to_mode(mode) .map(|arr| arr.into_series()), dt => Err(DaftError::ValueError(format!( - "Expected input to crop to be an Image type, but received: {}", - dt + "Expected input to crop to be an Image type, but received: {dt}" ))), } } diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs index a52092bd4e..ac77bddbfd 100644 --- a/src/daft-io/src/azure_blob.rs +++ b/src/daft-io/src/azure_blob.rs @@ -2,7 +2,7 @@ use std::{ops::Range, sync::Arc}; use async_trait::async_trait; use azure_core::{auth::TokenCredential, new_http_client}; -use azure_identity::{ClientSecretCredential, DefaultAzureCredential}; +use azure_identity::{ClientSecretCredential, DefaultAzureCredential, TokenCredentialOptions}; use azure_storage::{prelude::*, CloudLocation}; use azure_storage_blobs::{ blob::operations::GetBlobResponse, @@ -106,11 +106,11 @@ fn parse_azure_uri(uri: &str) -> super::Result<(String, Option<(String, String)> impl From for super::Error { fn from(error: Error) -> Self { - use Error::*; + use Error::{NotAFile, NotFound, UnableToOpenFile, UnableToReadBytes}; match error { UnableToReadBytes { path, source } | UnableToOpenFile { path, source } => { match source.as_http_error().map(|v| v.status().into()) { - Some(404) | Some(410) => Self::NotFound { + Some(404 | 410) => Self::NotFound { path, source: source.into(), }, @@ -138,7 +138,7 @@ impl From for super::Error { } } -pub(crate) struct AzureBlobSource { +pub struct AzureBlobSource { blob_client: Arc, } @@ -153,10 +153,11 @@ impl AzureBlobSource { return Err(Error::StorageAccountNotSet.into()); }; - let access_key = config - .access_key - .clone() - .or_else(|| std::env::var("AZURE_STORAGE_KEY").ok().map(|v| v.into())); + let access_key = config.access_key.clone().or_else(|| { + std::env::var("AZURE_STORAGE_KEY") + .ok() + .map(std::convert::Into::into) + }); let sas_token = config .sas_token .clone() @@ -184,7 +185,7 @@ impl AzureBlobSource { tenant_id.clone(), client_id.clone(), client_secret.as_string().clone(), - Default::default(), + TokenCredentialOptions::default(), ))) } else { let default_creds = Arc::new(DefaultAzureCredential::default()); @@ -216,7 +217,7 @@ impl AzureBlobSource { } else if config.use_fabric_endpoint { ClientBuilder::with_location( CloudLocation::Custom { - uri: format!("https://{}.blob.fabric.microsoft.com", storage_account), + uri: format!("https://{storage_account}.blob.fabric.microsoft.com"), }, storage_credentials, ) @@ -250,7 +251,7 @@ impl AzureBlobSource { responses_stream .map(move |response| { if let Some(is) = io_stats.clone() { - 
is.mark_list_requests(1) + is.mark_list_requests(1); } (response, protocol.clone()) }) @@ -294,7 +295,7 @@ impl AzureBlobSource { "{}{AZURE_DELIMITER}", prefix.trim_end_matches(&AZURE_DELIMITER) ); - let full_path = format!("{}://{}{}", protocol, container_name, prefix); + let full_path = format!("{protocol}://{container_name}{prefix}"); let full_path_with_trailing_delimiter = format!( "{}://{}{}", protocol, container_name, &prefix_with_delimiter @@ -333,10 +334,10 @@ impl AzureBlobSource { // Make sure the stream is pollable even if empty, // since we will chain it later with the two items we already popped. - let unchecked_results = if !stream_exhausted { - unchecked_results - } else { + let unchecked_results = if stream_exhausted { futures::stream::iter(vec![]).boxed() + } else { + unchecked_results }; match &maybe_first_two_items[..] { @@ -430,7 +431,7 @@ impl AzureBlobSource { responses_stream .map(move |response| { if let Some(is) = io_stats.clone() { - is.mark_list_requests(1) + is.mark_list_requests(1); } (response, protocol.clone(), container_name.clone()) }) @@ -528,7 +529,7 @@ impl ObjectSource for AzureBlobSource { .into() }); if let Some(is) = io_stats.as_ref() { - is.mark_get_requests(1) + is.mark_get_requests(1); } Ok(GetResult::Stream( io_stats_on_bytestream(Box::pin(stream), io_stats), @@ -565,7 +566,7 @@ impl ObjectSource for AzureBlobSource { .await .context(UnableToOpenFileSnafu:: { path: uri.into() })?; if let Some(is) = io_stats.as_ref() { - is.mark_head_requests(1) + is.mark_head_requests(1); } Ok(metadata.blob.properties.content_length as usize) diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index fe399ab3ec..d74484fa27 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -52,13 +52,16 @@ enum Error { impl From for super::Error { fn from(error: Error) -> Self { - use Error::*; + use Error::{ + InvalidUrl, NotAFile, NotFound, UnableToListObjects, UnableToLoadCredentials, + UnableToOpenFile, UnableToReadBytes, + }; match error { UnableToReadBytes { path, source } | UnableToOpenFile { path, source } | UnableToListObjects { path, source } => match source { GError::HttpClient(err) => match err.status().map(|s| s.as_u16()) { - Some(404) | Some(410) => Self::NotFound { + Some(404 | 410) => Self::NotFound { path, source: err.into(), }, @@ -164,7 +167,7 @@ impl GCSClientWrapper { .into() }); if let Some(is) = io_stats.as_ref() { - is.mark_get_requests(1) + is.mark_get_requests(1); } Ok(GetResult::Stream( io_stats_on_bytestream(response, io_stats), @@ -194,7 +197,7 @@ impl GCSClientWrapper { path: uri.to_string(), })?; if let Some(is) = io_stats.as_ref() { - is.mark_head_requests(1) + is.mark_head_requests(1); } Ok(response.size as usize) } @@ -214,8 +217,8 @@ impl GCSClientWrapper { prefix: Some(key.to_string()), end_offset: None, start_offset: None, - page_token: continuation_token.map(|s| s.to_string()), - delimiter: delimiter.map(|d| d.to_string()), // returns results in "directory mode" if delimiter is provided + page_token: continuation_token.map(std::string::ToString::to_string), + delimiter: delimiter.map(std::string::ToString::to_string), // returns results in "directory mode" if delimiter is provided max_results: page_size, include_trailing_delimiter: Some(false), // This will not populate "directories" in the response's .item[] projection: None, @@ -225,10 +228,10 @@ impl GCSClientWrapper { .list_objects(&req) .await .context(UnableToListObjectsSnafu { - path: format!("{GCS_SCHEME}://{}/{}", bucket, 
key), + path: format!("{GCS_SCHEME}://{bucket}/{key}"), })?; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } let response_items = ls_response.items.unwrap_or_default(); @@ -239,7 +242,7 @@ impl GCSClientWrapper { filetype: FileType::File, }); let dirs = response_prefixes.iter().map(|pref| FileMetadata { - filepath: format!("{GCS_SCHEME}://{}/{}", bucket, pref), + filepath: format!("{GCS_SCHEME}://{bucket}/{pref}"), size: None, filetype: FileType::Directory, }); @@ -264,7 +267,7 @@ impl GCSClientWrapper { if posix { // Attempt to forcefully ls the key as a directory (by ensuring a "/" suffix) let forced_directory_key = if key.is_empty() { - "".to_string() + String::new() } else { format!("{}{GCS_DELIMITER}", key.trim_end_matches(GCS_DELIMITER)) }; @@ -326,7 +329,7 @@ impl GCSClientWrapper { } } -pub(crate) struct GCSSource { +pub struct GCSSource { client: GCSClientWrapper, } diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index 14571fd79f..8f754aeb5d 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -138,16 +138,16 @@ fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result for super::Error { fn from(error: Error) -> Self { - use Error::*; + use Error::{UnableToDetermineSize, UnableToOpenFile}; match error { UnableToOpenFile { path, source } => match source.status().map(|v| v.as_u16()) { - Some(404) | Some(410) => Self::NotFound { + Some(404 | 410) => Self::NotFound { path, source: source.into(), }, @@ -210,7 +210,7 @@ impl ObjectSource for HttpSource { .error_for_status() .context(UnableToOpenFileSnafu:: { path: uri.into() })?; if let Some(is) = io_stats.as_ref() { - is.mark_get_requests(1) + is.mark_get_requests(1); } let size_bytes = response.content_length().map(|s| s as usize); let stream = response.bytes_stream(); @@ -250,7 +250,7 @@ impl ObjectSource for HttpSource { .context(UnableToOpenFileSnafu:: { path: uri.into() })?; if let Some(is) = io_stats.as_ref() { - is.mark_head_requests(1) + is.mark_head_requests(1); } let headers = response.headers(); @@ -306,7 +306,7 @@ impl ObjectSource for HttpSource { .error_for_status() .with_context(|_| UnableToOpenFileSnafu { path })?; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } // Reconstruct the actual path of the request, which may have been redirected via a 301 diff --git a/src/daft-io/src/huggingface.rs b/src/daft-io/src/huggingface.rs index f10f2d8de3..9095b05f87 100644 --- a/src/daft-io/src/huggingface.rs +++ b/src/daft-io/src/huggingface.rs @@ -130,9 +130,9 @@ impl FromStr for HFPathParts { } else { return Some(Self { bucket: bucket.to_string(), - repository: format!("{}/{}", username, uri), + repository: format!("{username}/{uri}"), revision: "main".to_string(), - path: "".to_string(), + path: String::new(), }); }; @@ -145,7 +145,7 @@ impl FromStr for HFPathParts { }; // {username}/{reponame} - let repository = format!("{}/{}", username, repository); + let repository = format!("{username}/{repository}"); // {path from root} // ^--------------^ let path = uri.to_string().trim_end_matches('/').to_string(); @@ -206,7 +206,7 @@ impl HFPathParts { } } -pub(crate) struct HFSource { +pub struct HFSource { http_source: HttpSource, } @@ -218,10 +218,10 @@ impl From for HFSource { impl From for super::Error { fn from(error: Error) -> Self { - use Error::*; + use Error::{UnableToDetermineSize, UnableToOpenFile}; match error { UnableToOpenFile { path, source } => match source.status().map(|v| 
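The page_token and delimiter rewrites above replace trivial closures with fully qualified method paths. A function path such as std::string::ToString::to_string coerces to the same callable as the closure it replaces:

fn main() {
    let token: Option<&str> = Some("next-page");
    let a = token.map(|s| s.to_string());
    let b = token.map(std::string::ToString::to_string);
    assert_eq!(a, b);
}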
v.as_u16()) { - Some(404) | Some(410) => Self::NotFound { + Some(404 | 410) => Self::NotFound { path, source: source.into(), }, @@ -294,7 +294,7 @@ impl ObjectSource for HFSource { .context(UnableToConnectSnafu:: { path: uri.into() })?; let response = response.error_for_status().map_err(|e| { - if let Some(401) = e.status().map(|s| s.as_u16()) { + if e.status().map(|s| s.as_u16()) == Some(401) { Error::Unauthorized } else { Error::UnableToOpenFile { @@ -305,7 +305,7 @@ impl ObjectSource for HFSource { })?; if let Some(is) = io_stats.as_ref() { - is.mark_get_requests(1) + is.mark_get_requests(1); } let size_bytes = response.content_length().map(|s| s as usize); let stream = response.bytes_stream(); @@ -344,7 +344,7 @@ impl ObjectSource for HFSource { .await .context(UnableToConnectSnafu:: { path: uri.into() })?; let response = response.error_for_status().map_err(|e| { - if let Some(401) = e.status().map(|s| s.as_u16()) { + if e.status().map(|s| s.as_u16()) == Some(401) { Error::Unauthorized } else { Error::UnableToOpenFile { @@ -355,7 +355,7 @@ impl ObjectSource for HFSource { })?; if let Some(is) = io_stats.as_ref() { - is.mark_head_requests(1) + is.mark_head_requests(1); } let headers = response.headers(); @@ -393,7 +393,7 @@ impl ObjectSource for HFSource { // hf://datasets/user/repo // but not // hf://datasets/user/repo/file.parquet - if let Some(FileFormat::Parquet) = file_format { + if file_format == Some(FileFormat::Parquet) { let res = try_parquet_api(glob_path, limit, io_stats.clone(), &self.http_source.client).await; match res { @@ -433,7 +433,7 @@ impl ObjectSource for HFSource { })?; let response = response.error_for_status().map_err(|e| { - if let Some(401) = e.status().map(|s| s.as_u16()) { + if e.status().map(|s| s.as_u16()) == Some(401) { Error::Unauthorized } else { Error::UnableToOpenFile { @@ -444,7 +444,7 @@ impl ObjectSource for HFSource { })?; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } let response = response .json::>() @@ -527,7 +527,7 @@ async fn try_parquet_api( })?; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } // {: {: [, ...]}} @@ -541,7 +541,7 @@ async fn try_parquet_api( let files = body .into_values() - .flat_map(|splits| splits.into_values()) + .flat_map(std::collections::HashMap::into_values) .flatten() .map(|uri| { Ok(FileMetadata { @@ -551,9 +551,9 @@ async fn try_parquet_api( }) }); - return Ok(Some( + Ok(Some( stream::iter(files).take(limit.unwrap_or(16 * 1024)).boxed(), - )); + )) } else { Ok(None) } diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 8d87f5b767..745fc4065c 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -149,7 +149,10 @@ pub enum Error { impl From for DaftError { fn from(err: Error) -> Self { - use Error::*; + use Error::{ + CachedError, ConnectTimeout, MiscTransient, NotFound, ReadTimeout, SocketError, + Throttled, UnableToReadBytes, + }; match err { NotFound { path, source } => Self::FileNotFound { path, source }, ConnectTimeout { .. 
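The 401 checks in this file are rewritten from literal patterns to plain equality; Option<u16> implements PartialEq, so the comparison reads more directly than binding a literal in if let:

fn main() {
    let status: Option<u16> = Some(401);
    // Before: if let Some(401) = status { ... }
    if status == Some(401) {
        println!("unauthorized");
    }
}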
} => Self::ConnectTimeout(err.into()), @@ -316,16 +319,17 @@ impl IOClient { match value { Some(Ok(bytes)) => Ok(Some(bytes)), - Some(Err(err)) => match raise_error_on_failure { - true => Err(err), - false => { + Some(Err(err)) => { + if raise_error_on_failure { + Err(err) + } else { log::warn!( - "Error occurred during url_download at index: {index} {} (falling back to Null)", - err - ); + "Error occurred during url_download at index: {index} {} (falling back to Null)", + err + ); Ok(None) } - }, + } None => Ok(None), } } @@ -390,7 +394,7 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { let expanded = home_dir.join(&input[2..]); let input = expanded.to_str()?; - Some((SourceType::File, Cow::Owned(format!("file://{}", input)))) + Some((SourceType::File, Cow::Owned(format!("file://{input}")))) }) .ok_or_else(|| crate::Error::InvalidArgument { msg: "Could not convert expanded path to string".to_string(), @@ -447,7 +451,7 @@ pub fn get_io_client(multi_thread: bool, config: Arc) -> DaftResult() { s.clone() } else if let Some(s) = e.downcast_ref::<&str>() { - s.to_string() + (*s).to_string() } else { "unknown internal error".to_string() }; @@ -488,7 +492,7 @@ impl Runtime { }); if tx.send(task_output).is_err() { - log::warn!("Spawned task output ignored: receiver dropped") + log::warn!("Spawned task output ignored: receiver dropped"); } }); rx.recv().expect("Spawned task transmitter dropped") @@ -524,22 +528,20 @@ fn init_runtime(num_threads: usize) -> Arc { } pub fn get_runtime(multi_thread: bool) -> DaftResult { - match multi_thread { - false => { - let runtime = SINGLE_THREADED_RUNTIME - .get_or_init(|| init_runtime(1)) - .clone(); - Ok(runtime) - } - true => { - let runtime = THREADED_RUNTIME - .get_or_init(|| init_runtime(*THREADED_RUNTIME_NUM_WORKER_THREADS)) - .clone(); - Ok(runtime) - } + if !multi_thread { + let runtime = SINGLE_THREADED_RUNTIME + .get_or_init(|| init_runtime(1)) + .clone(); + Ok(runtime) + } else { + let runtime = THREADED_RUNTIME + .get_or_init(|| init_runtime(*THREADED_RUNTIME_NUM_WORKER_THREADS)) + .clone(); + Ok(runtime) } } +#[must_use] pub fn get_io_pool_num_threads() -> Option { match tokio::runtime::Handle::try_current() { Ok(handle) => { diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index 4ed9eaa54b..9525c861c7 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -28,7 +28,7 @@ use crate::{ /// as long as there is no "mix" of "\" and "/". const PATH_SEGMENT_DELIMITER: &str = "/"; -pub(crate) struct LocalSource {} +pub struct LocalSource {} #[derive(Debug, Snafu)] enum Error { @@ -82,10 +82,13 @@ enum Error { impl From for super::Error { fn from(error: Error) -> Self { - use Error::*; + use Error::{ + UnableToFetchDirectoryEntries, UnableToFetchFileMetadata, UnableToOpenFile, + UnableToOpenFileForWriting, UnableToReadBytes, UnableToWriteToFile, + }; match error { UnableToOpenFile { path, source } | UnableToFetchDirectoryEntries { path, source } => { - use std::io::ErrorKind::*; + use std::io::ErrorKind::NotFound; match source.kind() { NotFound => Self::NotFound { path, @@ -98,7 +101,7 @@ impl From for super::Error { } } UnableToFetchFileMetadata { path, source } => { - use std::io::ErrorKind::*; + use std::io::ErrorKind::{IsADirectory, NotFound}; match source.kind() { NotFound | IsADirectory => Self::NotFound { path, @@ -277,7 +280,7 @@ impl ObjectSource for LocalSource { if meta.file_type().is_file() { // Provided uri points to a file, so only return that file. 
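The (*s).to_string() change above sits inside the runtime's panic-recovery path, which probes a caught panic payload for both common message types. A self-contained sketch of that probing (the helper name is illustrative):

use std::any::Any;

fn panic_message(e: &(dyn Any + Send)) -> String {
    if let Some(s) = e.downcast_ref::<String>() {
        s.clone()
    } else if let Some(s) = e.downcast_ref::<&str>() {
        (*s).to_string()
    } else {
        "unknown internal error".to_string()
    }
}

fn main() {
    let payload = std::panic::catch_unwind(|| panic!("boom")).unwrap_err();
    assert_eq!(panic_message(payload.as_ref()), "boom");
}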
return Ok(futures::stream::iter([Ok(FileMetadata { - filepath: format!("{}{}", LOCAL_PROTOCOL, uri), + filepath: format!("{LOCAL_PROTOCOL}{uri}"), size: Some(meta.len()), filetype: object_io::FileType::File, })]) @@ -334,7 +337,7 @@ impl ObjectSource for LocalSource { } } -pub(crate) async fn collect_file(local_file: LocalFile) -> Result { +pub async fn collect_file(local_file: LocalFile) -> Result { let path = &local_file.path; let mut file = tokio::fs::File::open(path) .await @@ -373,7 +376,6 @@ pub(crate) async fn collect_file(local_file: LocalFile) -> Result { } #[cfg(test)] - mod tests { use std::{default, io::Write}; diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 32bf328f17..6a3d27b4ef 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -77,7 +77,7 @@ where impl GetResult { pub async fn bytes(self) -> super::Result { - use GetResult::*; + use GetResult::{File, Stream}; let mut get_result = self; match get_result { File(f) => collect_file(f).await, @@ -90,10 +90,10 @@ impl GetResult { let mut result = collect_bytes(stream, size, permit).await; // drop permit to ensure quota for attempt in 1..NUM_TRIES { match result { - Err(super::Error::SocketError { .. }) - | Err(super::Error::UnableToReadBytes { .. }) - if let Some(rp) = &retry_params => - { + Err( + super::Error::SocketError { .. } + | super::Error::UnableToReadBytes { .. }, + ) if let Some(rp) = &retry_params => { let jitter = rand::thread_rng() .gen_range(0..((1 << (attempt - 1)) * JITTER_MS)) as u64; @@ -123,6 +123,7 @@ impl GetResult { } } + #[must_use] pub fn with_retry(self, params: StreamingRetryParams) -> Self { match self { Self::File(..) => self, @@ -133,7 +134,7 @@ impl GetResult { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum FileType { File, Directory, @@ -148,7 +149,7 @@ impl TryFrom for FileType { } else if value.is_file() { Ok(Self::File) } else if value.is_symlink() { - Err(DaftError::InternalError(format!("Symlinks should never be encountered when constructing FileMetadata, but got: {:?}", value))) + Err(DaftError::InternalError(format!("Symlinks should never be encountered when constructing FileMetadata, but got: {value:?}"))) } else { unreachable!( "Can only be a directory, file, or symlink, but got: {:?}", @@ -158,7 +159,7 @@ impl TryFrom for FileType { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct FileMetadata { pub filepath: String, pub size: Option, @@ -173,7 +174,7 @@ pub struct LSResult { use async_stream::stream; #[async_trait] -pub(crate) trait ObjectSource: Sync + Send { +pub trait ObjectSource: Sync + Send { async fn get( &self, uri: &str, diff --git a/src/daft-io/src/object_store_glob.rs b/src/daft-io/src/object_store_glob.rs index 13b43f773c..5380d9c9d5 100644 --- a/src/daft-io/src/object_store_glob.rs +++ b/src/daft-io/src/object_store_glob.rs @@ -34,7 +34,7 @@ const MARKER_FILES: [&str; 3] = ["_metadata", "_common_metadata", "_success"]; const MARKER_PREFIXES: [&str; 2] = ["_started", "_committed"]; #[derive(Clone)] -pub(crate) struct GlobState { +pub struct GlobState { // Current path in dirtree and glob_fragments pub current_path: String, pub current_fragment_idx: usize, @@ -62,7 +62,7 @@ impl GlobState { current_path: path, current_fragment_idx: idx, current_fanout: self.current_fanout * fanout_factor, - ..self.clone() + ..self } } @@ -75,7 +75,7 @@ impl GlobState { } #[derive(Debug, Clone)] -pub(crate) struct GlobFragment { +pub struct GlobFragment 
{ data: String, escaped_data: String, first_wildcard_idx: Option, @@ -113,16 +113,13 @@ impl GlobFragment { let mut ptr = 0; while ptr < data.len() { let remaining = &data[ptr..]; - match remaining.find(r"\\") { - Some(backslash_idx) => { - escaped_data.push_str(&remaining[..backslash_idx].replace('\\', "")); - escaped_data.extend(std::iter::once('\\')); - ptr += backslash_idx + 2; - } - None => { - escaped_data.push_str(&remaining.replace('\\', "")); - break; - } + if let Some(backslash_idx) = remaining.find(r"\\") { + escaped_data.push_str(&remaining[..backslash_idx].replace('\\', "")); + escaped_data.extend(std::iter::once('\\')); + ptr += backslash_idx + 2; + } else { + escaped_data.push_str(&remaining.replace('\\', "")); + break; } } @@ -168,7 +165,7 @@ impl GlobFragment { /// 2. Non-wildcard fragments are joined and coalesced by delimiter /// 3. The first fragment is prefixed by "{scheme}://" /// 4. Preserves any leading delimiters -pub(crate) fn to_glob_fragments(glob_str: &str) -> super::Result> { +pub fn to_glob_fragments(glob_str: &str) -> super::Result> { // NOTE: We only use the URL parse library to get the scheme, because it will escape some of our glob special characters // such as ? and {} let glob_url = url::Url::parse(glob_str).map_err(|e| super::Error::InvalidUrl { @@ -286,10 +283,7 @@ async fn ls_with_prefix_fallback( // STOP EARLY!! // If the number of directory results are more than `max_dirs`, we terminate the function early, // throw away our results buffer and return a stream of FileType::File files using `prefix_ls` instead - if max_dirs - .map(|max_dirs| dir_count_so_far > max_dirs) - .unwrap_or(false) - { + if max_dirs.is_some_and(|max_dirs| dir_count_so_far > max_dirs) { return ( prefix_ls(source.clone(), uri.to_string(), page_size, io_stats), 0, @@ -357,7 +351,7 @@ fn _should_return(fm: &FileMetadata) -> bool { /// parallel connections (usually defaulting to 64). /// * page_size: control the returned results page size, or None to use the ObjectSource's defaults. Usually only used for testing /// but may yield some performance improvements depending on the workload. -pub(crate) async fn glob( +pub async fn glob( source: Arc, glob: &str, fanout_limit: Option, @@ -385,7 +379,7 @@ pub(crate) async fn glob( } if attempt_as_dir { let mut results = source.iter_dir(glob.as_str(), true, page_size, io_stats).await?; - while let Some(result) = results.next().await && remaining_results.map(|rr| rr > 0).unwrap_or(true) { + while let Some(result) = results.next().await && remaining_results.map_or(true, |rr| rr > 0) { match result { Ok(fm) => { if _should_return(&fm) { @@ -560,7 +554,7 @@ pub(crate) async fn glob( } else if current_fragment.has_special_character() { let partial_glob_matcher = GlobBuilder::new( GlobFragment::join( - &state.glob_fragments[..state.current_fragment_idx + 1], + &state.glob_fragments[..=state.current_fragment_idx], GLOB_DELIMITER, ) .raw_str(), @@ -641,7 +635,7 @@ pub(crate) async fn glob( to_rtn_tx, source.clone(), GlobState { - current_path: "".to_string(), + current_path: String::new(), current_fragment_idx: 0, glob_fragments: Arc::new(glob_fragments), full_glob_matcher: Arc::new(full_glob_matcher), @@ -655,7 +649,7 @@ pub(crate) async fn glob( let to_rtn_stream = stream! 
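The max_dirs early-exit check in ls_with_prefix_fallback above now uses Option::is_some_and (stable since Rust 1.70) in place of the map(..).unwrap_or(false) chain:

fn main() {
    let max_dirs: Option<usize> = Some(8);
    let dir_count_so_far = 10;
    // Before: max_dirs.map(|m| dir_count_so_far > m).unwrap_or(false)
    if max_dirs.is_some_and(|m| dir_count_so_far > m) {
        println!("too many directories; fall back to prefix listing");
    }
}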
{ let mut remaining_results = limit; - while remaining_results.map(|rr| rr > 0).unwrap_or(true) && let Some(v) = to_rtn_rx.recv().await { + while remaining_results.map_or(true, |rr| rr > 0) && let Some(v) = to_rtn_rx.recv().await { if v.as_ref().is_ok_and(|v| !_should_return(v)) { continue diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index 484911b6b4..6dac52af8a 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -20,7 +20,7 @@ mod py { ) -> PyResult>> { let multithreaded_io = multithreaded_io.unwrap_or(true); let io_stats = IOStatsContext::new(format!("io_glob for {path}")); - let io_stats_handle = io_stats.clone(); + let io_stats_handle = io_stats; let lsr: DaftResult> = py.allow_threads(|| { let io_client = get_io_client( diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index e6eb829a78..1604bf0aff 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -43,7 +43,7 @@ use crate::{ const S3_DELIMITER: &str = "/"; const DEFAULT_GLOB_FANOUT_LIMIT: usize = 1024; -pub(crate) struct S3LikeSource { +pub struct S3LikeSource { region_to_client_map: tokio::sync::RwLock>>, connection_pool_sema: Arc, default_region: Region, @@ -141,7 +141,10 @@ const THROTTLING_ERRORS: &[&str] = &[ impl From for super::Error { fn from(error: Error) -> Self { - use Error::*; + use Error::{ + InvalidUrl, NotAFile, NotFound, UnableToHeadFile, UnableToListObjects, + UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes, + }; fn classify_unhandled_error< E: std::error::Error + ProvideErrorMetadata + Send + Sync + 'static, @@ -296,7 +299,7 @@ impl From for super::Error { } /// Retrieves an S3Config from the environment by leveraging the AWS SDK's credentials chain -pub(crate) async fn s3_config_from_env() -> super::Result { +pub async fn s3_config_from_env() -> super::Result { let default_s3_config = S3Config::default(); let (anonymous, s3_conf) = build_s3_conf(&default_s3_config, None).await?; let creds = s3_conf @@ -307,7 +310,7 @@ pub(crate) async fn s3_config_from_env() -> super::Result { let key_id = Some(creds.access_key_id().to_string()); let access_key = Some(creds.secret_access_key().to_string().into()); let session_token = creds.session_token().map(|t| t.to_string().into()); - let region_name = s3_conf.region().map(|r| r.to_string()); + let region_name = s3_conf.region().map(std::string::ToString::to_string); Ok(S3Config { // Do not perform auto-discovery of endpoint_url. This is possible, but requires quite a bit // of work that our current implementation of `build_s3_conf` does not yet do. 
See smithy-rs code: @@ -402,11 +405,7 @@ async fn build_s3_conf( .as_ref() .map(|s| s.as_string().clone()) .unwrap(), - config - .session_token - .as_ref() - .map(|s| s.as_string().clone()) - .clone(), + config.session_token.as_ref().map(|s| s.as_string().clone()), ); Some(aws_credential_types::provider::SharedCredentialsProvider::new(creds)) } else if config.access_key.is_some() || config.key_id.is_some() { @@ -442,7 +441,7 @@ async fn build_s3_conf( CredentialsCache::lazy_builder() .buffer_time(Duration::from_secs(*buffer_time)) .into_credentials_cache(), - ) + ); } loader.load().await @@ -481,7 +480,7 @@ async fn build_s3_conf( } else if retry_mode.trim().eq_ignore_ascii_case("standard") { retry_config } else { - return Err(crate::Error::InvalidArgument { msg: format!("Invalid Retry Mode, Daft S3 client currently only supports standard and adaptive, got {}", retry_mode) }); + return Err(crate::Error::InvalidArgument { msg: format!("Invalid Retry Mode, Daft S3 client currently only supports standard and adaptive, got {retry_mode}") }); } } else { retry_config @@ -507,7 +506,7 @@ async fn build_s3_conf( const MAX_WAITTIME_MS: u64 = 45_000; let check_creds = async || -> super::Result { use rand::Rng; - use CredentialsError::*; + use CredentialsError::{CredentialsNotLoaded, ProviderTimedOut}; let mut attempt = 0; let first_attempt_time = std::time::Instant::now(); loop { @@ -518,22 +517,21 @@ async fn build_s3_conf( attempt += 1; match creds { Ok(_) => return Ok(false), - Err(err @ ProviderTimedOut(..)) => { + Err(err @ ProviderTimedOut(..)) => { let total_time_waited_ms: u64 = first_attempt_time.elapsed().as_millis().try_into().unwrap(); if attempt < CRED_TRIES && (total_time_waited_ms < MAX_WAITTIME_MS) { - let jitter = rand::thread_rng().gen_range(0..((1< { log::warn!("S3 Credentials not provided or found when making client for {}! Reverting to Anonymous mode. 
{err}", s3_conf.region().unwrap_or(&DEFAULT_REGION)); - return Ok(true) - }, + return Ok(true); + } Err(err) => Err(err), }.with_context(|_| UnableToLoadCredentialsSnafu {})?; } @@ -726,7 +724,7 @@ impl S3LikeSource { #[async_recursion] async fn _head_impl( &self, - _permit: SemaphorePermit<'async_recursion>, + permit: SemaphorePermit<'async_recursion>, uri: &str, region: &Region, ) -> super::Result { @@ -794,7 +792,7 @@ impl S3LikeSource { let new_region = Region::new(region_name); log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting HEAD in that region with new client", new_region, region); - self._head_impl(_permit, uri, &new_region).await + self._head_impl(permit, uri, &new_region).await } _ => Err(UnableToHeadFileSnafu { path: uri } .into_error(SdkError::ServiceError(err)) @@ -810,7 +808,7 @@ impl S3LikeSource { #[async_recursion] async fn _list_impl( &self, - _permit: SemaphorePermit<'async_recursion>, + permit: SemaphorePermit<'async_recursion>, scheme: &str, bucket: &str, key: &str, @@ -875,13 +873,15 @@ impl S3LikeSource { Ok(v) => { let dirs = v.common_prefixes(); let files = v.contents(); - let continuation_token = v.next_continuation_token().map(|s| s.to_string()); + let continuation_token = v + .next_continuation_token() + .map(std::string::ToString::to_string); let mut total_len = 0; if let Some(dirs) = dirs { - total_len += dirs.len() + total_len += dirs.len(); } if let Some(files) = files { - total_len += files.len() + total_len += files.len(); } let mut all_files = Vec::with_capacity(total_len); if let Some(dirs) = dirs { @@ -934,7 +934,7 @@ impl S3LikeSource { let new_region = Region::new(region_name); log::debug!("S3 Region of {uri} different than client {:?} vs {:?} Attempting List in that region with new client", new_region, region); self._list_impl( - _permit, + permit, scheme, bucket, key, @@ -1023,7 +1023,7 @@ impl ObjectSource for S3LikeSource { if io_stats.is_some() { if let GetResult::Stream(stream, num_bytes, permit, retry_params) = get_result { if let Some(is) = io_stats.as_ref() { - is.mark_get_requests(1) + is.mark_get_requests(1); } Ok(GetResult::Stream( io_stats_on_bytestream(stream, io_stats), @@ -1071,7 +1071,7 @@ impl ObjectSource for S3LikeSource { .context(UnableToGrabSemaphoreSnafu)?; let head_result = self._head_impl(permit, uri, &self.default_region).await?; if let Some(is) = io_stats.as_ref() { - is.mark_head_requests(1) + is.mark_head_requests(1); } Ok(head_result) } @@ -1115,7 +1115,7 @@ impl ObjectSource for S3LikeSource { // Perform a directory-based list of entries in the next level // assume its a directory first let key = if key.is_empty() { - "".to_string() + String::new() } else { format!("{}{S3_DELIMITER}", key.trim_end_matches(S3_DELIMITER)) }; @@ -1139,7 +1139,7 @@ impl ObjectSource for S3LikeSource { .await? }; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } if lsr.files.is_empty() && key.contains(S3_DELIMITER) { @@ -1163,7 +1163,7 @@ impl ObjectSource for S3LikeSource { ) .await?; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } let target_path = format!("{scheme}://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); @@ -1198,7 +1198,7 @@ impl ObjectSource for S3LikeSource { .await? 
}; if let Some(is) = io_stats.as_ref() { - is.mark_list_requests(1) + is.mark_list_requests(1); } Ok(lsr) @@ -1208,7 +1208,6 @@ impl ObjectSource for S3LikeSource { #[cfg(test)] mod tests { - use common_io_config::S3Config; use crate::{object_io::ObjectSource, Result, S3LikeSource}; diff --git a/src/daft-io/src/stats.rs b/src/daft-io/src/stats.rs index 32aabd1b90..a4e70cf2ce 100644 --- a/src/daft-io/src/stats.rs +++ b/src/daft-io/src/stats.rs @@ -41,7 +41,7 @@ impl Drop for IOStatsContext { } } -pub(crate) struct IOStatsByteStreamContextHandle { +pub struct IOStatsByteStreamContextHandle { // do not enable Copy or Clone on this struct bytes_read: usize, inner: IOStatsRef, diff --git a/src/daft-io/src/stream_utils.rs b/src/daft-io/src/stream_utils.rs index 4ed42811d3..a18eb30e9d 100644 --- a/src/daft-io/src/stream_utils.rs +++ b/src/daft-io/src/stream_utils.rs @@ -3,7 +3,7 @@ use futures::{stream::BoxStream, StreamExt}; use crate::stats::{IOStatsByteStreamContextHandle, IOStatsRef}; -pub(crate) fn io_stats_on_bytestream( +pub fn io_stats_on_bytestream( mut s: impl futures::stream::Stream> + Unpin + std::marker::Send diff --git a/src/daft-json/src/decoding.rs b/src/daft-json/src/decoding.rs index 65f090f88b..96aad7d241 100644 --- a/src/daft-json/src/decoding.rs +++ b/src/daft-json/src/decoding.rs @@ -24,7 +24,7 @@ use simd_json::StaticNode; use crate::deserializer::Value as BorrowedValue; const JSON_NULL_VALUE: BorrowedValue = BorrowedValue::Static(StaticNode::Null); /// Deserialize chunk of JSON records into a chunk of Arrow2 arrays. -pub(crate) fn deserialize_records<'a, A: Borrow>>( +pub fn deserialize_records<'a, A: Borrow>>( records: &[A], schema: &Schema, schema_is_projection: bool, @@ -38,7 +38,7 @@ pub(crate) fn deserialize_records<'a, A: Borrow>>( for record in records { match record.borrow() { BorrowedValue::Object(record) => { - for (key, value) in record.iter() { + for (key, value) in record { let arr = results.get_mut(key.as_ref()); if let Some(arr) = arr { deserialize_into(arr, &[value]); @@ -62,7 +62,7 @@ pub(crate) fn deserialize_records<'a, A: Borrow>>( Ok(results.into_values().map(|mut ma| ma.as_box()).collect()) } -pub(crate) fn allocate_array(f: &Field, length: usize) -> Box { +pub fn allocate_array(f: &Field, length: usize) -> Box { match f.data_type() { DataType::Null => Box::new(MutableNullArray::new(DataType::Null, 0)), DataType::Int8 => Box::new(MutablePrimitiveArray::::with_capacity(length)), @@ -126,7 +126,7 @@ pub(crate) fn allocate_array(f: &Field, length: usize) -> Box } /// Deserialize `rows` by extending them into the given `target` -pub(crate) fn deserialize_into<'a, A: Borrow>>( +pub fn deserialize_into<'a, A: Borrow>>( target: &mut Box, rows: &[A], ) { @@ -134,7 +134,7 @@ pub(crate) fn deserialize_into<'a, A: Borrow>>( DataType::Null => { // TODO(Clark): Return an error if any of rows are not Value::Null. 
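allocate_array and deserialize_into below follow arrow2's mutable-array protocol: allocate one mutable array per column, push one value or null per record, then freeze it with as_box(). In miniature:

use arrow2::array::{Array, MutableArray, MutablePrimitiveArray};

fn main() {
    let mut col = MutablePrimitiveArray::<i64>::with_capacity(3);
    col.push(Some(1));
    col.push_null(); // a record missing this key
    col.push(Some(3));
    let frozen: Box<dyn Array> = col.as_box();
    assert_eq!(frozen.len(), 3);
    assert_eq!(frozen.null_count(), 1);
}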
for _ in 0..rows.len() { - target.push_null() + target.push_null(); } } DataType::Boolean => generic_deserialize_into(target, rows, deserialize_boolean_into), @@ -143,17 +143,17 @@ pub(crate) fn deserialize_into<'a, A: Borrow>>( DataType::Int8 => deserialize_primitive_into::<_, i8>(target, rows), DataType::Int16 => deserialize_primitive_into::<_, i16>(target, rows), DataType::Int32 | DataType::Interval(IntervalUnit::YearMonth) => { - deserialize_primitive_into::<_, i32>(target, rows) + deserialize_primitive_into::<_, i32>(target, rows); } DataType::Date32 | DataType::Time32(_) => deserialize_date_into(target, rows), DataType::Interval(IntervalUnit::DayTime) => { unimplemented!("There is no natural representation of DayTime in JSON.") } DataType::Int64 | DataType::Duration(_) => { - deserialize_primitive_into::<_, i64>(target, rows) + deserialize_primitive_into::<_, i64>(target, rows); } DataType::Timestamp(..) | DataType::Date64 | DataType::Time64(_) => { - deserialize_datetime_into(target, rows) + deserialize_datetime_into(target, rows); } DataType::UInt8 => deserialize_primitive_into::<_, u8>(target, rows), DataType::UInt16 => deserialize_primitive_into::<_, u16>(target, rows), @@ -170,7 +170,7 @@ pub(crate) fn deserialize_into<'a, A: Borrow>>( deserialize_utf8_into, ), DataType::FixedSizeList(_, _) => { - generic_deserialize_into(target, rows, deserialize_fixed_size_list_into) + generic_deserialize_into(target, rows, deserialize_fixed_size_list_into); } DataType::List(_) => deserialize_list_into( target @@ -187,7 +187,11 @@ pub(crate) fn deserialize_into<'a, A: Borrow>>( rows, ), DataType::Struct(_) => { - generic_deserialize_into::<_, MutableStructArray>(target, rows, deserialize_struct_into) + generic_deserialize_into::<_, MutableStructArray>( + target, + rows, + deserialize_struct_into, + ); } // TODO(Clark): Add support for decimal type. // TODO(Clark): Add support for binary and large binary types. @@ -234,7 +238,7 @@ fn deserialize_utf8_into<'a, O: Offset, A: Borrow>>( match row.borrow() { BorrowedValue::String(v) => target.push(Some(v.as_ref())), BorrowedValue::Static(StaticNode::Bool(v)) => { - target.push(Some(if *v { "true" } else { "false" })) + target.push(Some(if *v { "true" } else { "false" })); } BorrowedValue::Static(node) if !matches!(node, StaticNode::Null) => { write!(scratch, "{node}").unwrap(); @@ -401,7 +405,7 @@ fn deserialize_struct_into<'a, A: Borrow>>( .collect::>(), _ => unreachable!(), }; - rows.iter().for_each(|row| { + for row in rows { match row.borrow() { BorrowedValue::Object(value) => { values.iter_mut().for_each(|(s, inner)| { @@ -416,7 +420,7 @@ fn deserialize_struct_into<'a, A: Borrow>>( target.push(false); } }; - }); + } // Then deserialize each field's JSON values buffer to the appropriate Arrow2 array. 
// // Column ordering invariant - this assumes that values and target.mut_values() have aligned columns; diff --git a/src/daft-json/src/deserializer.rs b/src/daft-json/src/deserializer.rs index c9342ff9ad..dabde80368 100644 --- a/src/daft-json/src/deserializer.rs +++ b/src/daft-json/src/deserializer.rs @@ -7,7 +7,7 @@ pub type Object<'value> = IndexMap, Value<'value>>; /// Borrowed JSON-DOM Value, consider using the `ValueTrait` /// to access its content #[derive(Debug, Clone)] -pub(crate) enum Value<'value> { +pub enum Value<'value> { /// Static values Static(StaticNode), /// string type diff --git a/src/daft-json/src/inference.rs b/src/daft-json/src/inference.rs index 76569aecc0..0d88515036 100644 --- a/src/daft-json/src/inference.rs +++ b/src/daft-json/src/inference.rs @@ -12,7 +12,7 @@ use crate::deserializer::{Object, Value as BorrowedValue}; const ITEM_NAME: &str = "item"; /// Infer Arrow2 schema from JSON Value record. -pub(crate) fn infer_records_schema(record: &BorrowedValue) -> Result { +pub fn infer_records_schema(record: &BorrowedValue) -> Result { let fields = match record { BorrowedValue::Object(record) => record .iter() @@ -97,7 +97,7 @@ fn infer_array(values: &[BorrowedValue]) -> Result { /// Convert each column's set of inferred dtypes to a field with a consolidated dtype, following the coercion rules /// defined in coerce_data_type. -pub(crate) fn column_types_map_to_fields( +pub fn column_types_map_to_fields( column_types: IndexMap>, ) -> Vec { column_types @@ -116,7 +116,7 @@ pub(crate) fn column_types_map_to_fields( /// * Lists and scalars are coerced to a list of a compatible scalar /// * Structs contain the union of all fields /// * All other types are coerced to `Utf8` -pub(crate) fn coerce_data_type(mut datatypes: HashSet) -> DataType { +pub fn coerce_data_type(mut datatypes: HashSet) -> DataType { // Drop null dtype from the dtype set. datatypes.remove(&DataType::Null); diff --git a/src/daft-json/src/local.rs b/src/daft-json/src/local.rs index 224c94f24f..d5c9828921 100644 --- a/src/daft-json/src/local.rs +++ b/src/daft-json/src/local.rs @@ -117,9 +117,9 @@ impl<'a> JsonReader<'a> { let mut total_rows = 128; if let Some((mean, std)) = get_line_stats_json(bytes, self.sample_size) { - let line_length_upper_bound = mean + 1.1 * std; + let line_length_upper_bound = 1.1f32.mul_add(std, mean); - total_rows = (bytes.len() as f32 / (mean - 0.01 * std)) as usize; + total_rows = (bytes.len() as f32 / 0.01f32.mul_add(-std, mean)) as usize; if let Some(n_rows) = self.n_rows { total_rows = std::cmp::min(n_rows, total_rows); // the guessed upper bound of the no. 
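The line-length statistics below now use fused multiply-add: a.mul_add(b, c) computes a * b + c with a single rounding step, and typically one instruction where FMA hardware is available. The rewrite preserves the original arithmetic:

fn main() {
    let (mean, std) = (100.0f32, 12.0f32);
    let upper = 1.1f32.mul_add(std, mean); // mean + 1.1 * std
    let divisor = 0.01f32.mul_add(-std, mean); // mean - 0.01 * std
    assert!((upper - (mean + 1.1 * std)).abs() < 1e-3);
    assert!((divisor - (mean - 0.01 * std)).abs() < 1e-3);
}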
of bytes in the file @@ -127,7 +127,7 @@ impl<'a> JsonReader<'a> { if n_bytes < bytes.len() { if let Some(pos) = next_line_position(&bytes[n_bytes..]) { - bytes = &bytes[..n_bytes + pos] + bytes = &bytes[..n_bytes + pos]; } } } @@ -197,7 +197,7 @@ impl<'a> JsonReader<'a> { match v { Value::Object(record) => { - for (s, inner) in columns.iter_mut() { + for (s, inner) in &mut columns { match record.get(s) { Some(value) => { deserialize_into(inner, &[value]); @@ -225,10 +225,7 @@ impl<'a> JsonReader<'a> { .zip(daft_fields) .map(|(mut ma, fld)| { let arr = ma.as_box(); - Series::try_from_field_and_arrow_array( - fld.clone(), - cast_array_for_daft_if_needed(arr), - ) + Series::try_from_field_and_arrow_array(fld, cast_array_for_daft_if_needed(arr)) }) .collect::>>()?; @@ -368,8 +365,8 @@ fn get_line_stats_json(bytes: &[u8], n_lines: usize) -> Option<(f32, f32)> { let n_samples = lengths.len(); let mean = (n_read as f32) / (n_samples as f32); let mut std = 0.0; - for &len in lengths.iter() { - std += (len as f32 - mean).pow(2.0) + for &len in &lengths { + std += (len as f32 - mean).pow(2.0); } std = (std / n_samples as f32).sqrt(); Some((mean, std)) @@ -463,7 +460,7 @@ mod tests { #[test] fn test_infer_schema_empty() { - let json = r#""#; + let json = r""; let result = infer_schema(json.as_bytes(), None, None); let expected_schema = ArrowSchema::from(vec![]); diff --git a/src/daft-json/src/options.rs b/src/daft-json/src/options.rs index be045e16fa..f9ae79cf51 100644 --- a/src/daft-json/src/options.rs +++ b/src/daft-json/src/options.rs @@ -83,7 +83,7 @@ impl JsonConvertOptions { Self::new_internal( limit, include_columns, - schema.map(|s| s.into()), + schema.map(std::convert::Into::into), predicate.map(|p| p.expr), ) } diff --git a/src/daft-json/src/read.rs b/src/daft-json/src/read.rs index 7396e6ca04..ba9933a46b 100644 --- a/src/daft-json/src/read.rs +++ b/src/daft-json/src/read.rs @@ -78,7 +78,7 @@ pub fn read_json_bulk( // Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks. let task_stream = futures::stream::iter(uris.iter().map(|uri| { let (uri, convert_options, parse_options, read_options, io_client, io_stats) = ( - uri.to_string(), + (*uri).to_string(), convert_options.clone(), parse_options.clone(), read_options.clone(), @@ -164,7 +164,7 @@ pub(crate) fn tables_concat(mut tables: Vec
<Table>) -> DaftResult<Table>
{ Table::new_with_size( first_table.schema.clone(), new_series, - tables.iter().map(|t| t.len()).sum(), + tables.iter().map(daft_table::Table::len).sum(), ) } @@ -205,7 +205,7 @@ async fn read_json_single_into_table( let required_columns_for_predicate = get_required_columns(predicate); for rc in required_columns_for_predicate { if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { - include_columns.push(rc) + include_columns.push(rc); } } } @@ -312,7 +312,7 @@ pub async fn stream_json( let required_columns_for_predicate = get_required_columns(predicate); for rc in required_columns_for_predicate { if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { - include_columns.push(rc) + include_columns.push(rc); } } } @@ -595,7 +595,7 @@ mod tests { // Get consolidated schema from parsed JSON. let mut column_types: IndexMap> = IndexMap::new(); - parsed.iter().for_each(|record| { + for record in &parsed { let schema = infer_records_schema(record).unwrap(); for field in schema.fields { match column_types.entry(field.name) { @@ -609,7 +609,7 @@ mod tests { } } } - }); + } let fields = column_types_map_to_fields(column_types); let schema: arrow2::datatypes::Schema = fields.into(); // Apply projection to schema. @@ -673,7 +673,7 @@ mod tests { let file = format!( "{}/test/iris_tiny.jsonl{}", env!("CARGO_MANIFEST_DIR"), - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{}", ext)) ); let mut io_config = IOConfig::default(); @@ -1193,7 +1193,7 @@ mod tests { ) -> DaftResult<()> { let file = format!( "s3://daft-public-data/test_fixtures/json-dev/iris_tiny.jsonl{}", - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{}", ext)) ); let mut io_config = IOConfig::default(); diff --git a/src/daft-json/src/schema.rs b/src/daft-json/src/schema.rs index 5a8e37aa85..e867c513c5 100644 --- a/src/daft-json/src/schema.rs +++ b/src/daft-json/src/schema.rs @@ -81,7 +81,7 @@ pub async fn read_json_schema_bulk( let result = runtime_handle .block_on_current_thread(async { let task_stream = futures::stream::iter(uris.iter().map(|uri| { - let owned_string = uri.to_string(); + let owned_string = (*uri).to_string(); let owned_client = io_client.clone(); let owned_io_stats = io_stats.clone(); let owned_parse_options = parse_options.clone(); @@ -231,14 +231,14 @@ mod tests { let file = format!( "{}/test/iris_tiny.jsonl{}", env!("CARGO_MANIFEST_DIR"), - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{}", ext)) ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -323,7 +323,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -349,7 +349,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let schema = read_json_schema(file.as_ref(), None, 
None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -374,7 +374,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, Some(100), io_client.clone(), None)?; + let schema = read_json_schema(file.as_ref(), None, Some(100), io_client, None)?; assert_eq!( schema, Schema::new(vec![ @@ -416,14 +416,14 @@ mod tests { ) -> DaftResult<()> { let file = format!( "s3://daft-public-data/test_fixtures/json-dev/iris_tiny.jsonl{}", - compression.map_or("".to_string(), |ext| format!(".{}", ext)) + compression.map_or(String::new(), |ext| format!(".{}", ext)) ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_json_schema(file.as_ref(), None, None, io_client.clone(), None)?; + let schema = read_json_schema(file.as_ref(), None, None, io_client, None)?; assert_eq!( schema, Schema::new(vec![ diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index 4bc6fd1f5c..f16e3bd061 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -19,19 +19,16 @@ pub struct PipelineChannel { impl PipelineChannel { pub fn new(buffer_size: usize, in_order: bool) -> Self { - match in_order { - true => { - let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); - let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); - let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); - Self { sender, receiver } - } - false => { - let (sender, receiver) = create_channel(buffer_size); - let sender = PipelineSender::OutOfOrder(sender); - let receiver = PipelineReceiver::OutOfOrder(receiver); - Self { sender, receiver } - } + if in_order { + let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); + let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); + let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); + Self { sender, receiver } + } else { + let (sender, receiver) = create_channel(buffer_size); + let sender = PipelineSender::OutOfOrder(sender); + let receiver = PipelineReceiver::OutOfOrder(receiver); + Self { sender, receiver } } } diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 13e79b5ede..525e308ebe 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -19,7 +19,7 @@ enum AntiSemiProbeState { impl AntiSemiProbeState { fn set_table(&mut self, table: &Arc) { - if let Self::Building = self { + if matches!(self, Self::Building) { *self = Self::ReadyToProbe(table.clone()); } else { panic!("AntiSemiProbeState should only be in Building state when setting table") @@ -57,7 +57,7 @@ impl AntiSemiProbeOperator { fn probe_anti_semi( &self, input: &Arc, - state: &mut AntiSemiProbeState, + state: &AntiSemiProbeState, ) -> DaftResult> { let probe_set = state.get_probeable(); @@ -102,30 +102,23 @@ impl IntermediateOperator for AntiSemiProbeOperator { input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { - match idx { - 0 => { - let state = state - .expect("AntiSemiProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("AntiSemiProbeOperator state 
should be AntiSemiProbeState"); - let (probe_table, _) = input.as_probe_table(); - state.set_table(probe_table); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } - _ => { - let state = state - .expect("AntiSemiProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); - let input = input.as_data(); - let out = match self.join_type { - JoinType::Semi | JoinType::Anti => self.probe_anti_semi(input, state), - _ => unreachable!("Only Semi and Anti joins are supported"), - }?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } + let state = state + .expect("AntiSemiProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); + + if idx == 0 { + let (probe_table, _) = input.as_probe_table(); + state.set_table(probe_table); + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } else { + let input = input.as_data(); + let out = match self.join_type { + JoinType::Semi | JoinType::Anti => self.probe_anti_semi(input, state), + _ => unreachable!("Only Semi and Anti joins are supported"), + }?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs index 67b17c5380..3c66301610 100644 --- a/src/daft-local-execution/src/intermediate_ops/buffer.rs +++ b/src/daft-local-execution/src/intermediate_ops/buffer.rs @@ -1,4 +1,8 @@ -use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc}; +use std::{ + cmp::Ordering::{Equal, Greater, Less}, + collections::VecDeque, + sync::Arc, +}; use common_error::DaftResult; use daft_micropartition::MicroPartition; @@ -57,8 +61,13 @@ impl OperatorBuffer { self.curr_len -= self.threshold; match to_concat.len() { 1 => Ok(to_concat.pop().unwrap()), - _ => MicroPartition::concat(&to_concat.iter().map(|x| x.as_ref()).collect::>()) - .map(Arc::new), + _ => MicroPartition::concat( + &to_concat + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(), + ) + .map(Arc::new), } } @@ -67,9 +76,14 @@ impl OperatorBuffer { return None; } - let concated = - MicroPartition::concat(&self.buffer.iter().map(|x| x.as_ref()).collect::>()) - .map(Arc::new); + let concated = MicroPartition::concat( + &self + .buffer + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(), + ) + .map(Arc::new); self.buffer.clear(); self.curr_len = 0; Some(concated) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index 0a037dc6bb..dd53b9eac4 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -21,7 +21,7 @@ enum HashJoinProbeState { impl HashJoinProbeState { fn set_table(&mut self, table: &Arc, tables: &Arc>) { - if let Self::Building = self { + if matches!(self, Self::Building) { *self = Self::ReadyToProbe(table.clone(), tables.clone()); } else { panic!("HashJoinProbeState should only be in Building state when setting table") @@ -98,7 +98,7 @@ impl HashJoinProbeOperator { fn probe_inner( &self, input: &Arc, - state: &mut HashJoinProbeState, + state: &HashJoinProbeState, ) -> DaftResult> { let (probe_table, tables) = state.get_probeable_and_table(); @@ -161,7 +161,7 @@ impl HashJoinProbeOperator { fn probe_left_right( &self, input: &Arc, - state: &mut HashJoinProbeState, + state: 
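Both probe operators now hoist the state downcast above the idx branch instead of repeating it in every arm. The underlying pattern is a boxed trait object recovered through Any; a self-contained sketch (trait and struct names are illustrative stand-ins):

use std::any::Any;

trait OperatorState {
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

struct ProbeState {
    rows_seen: usize,
}

impl OperatorState for ProbeState {
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

fn main() {
    let mut state: Box<dyn OperatorState> = Box::new(ProbeState { rows_seen: 0 });
    // One downcast up front, then every branch can use the concrete type.
    let probe = state
        .as_any_mut()
        .downcast_mut::<ProbeState>()
        .expect("state should be ProbeState");
    probe.rows_seen += 1;
    assert_eq!(probe.rows_seen, 1);
}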
&HashJoinProbeState, ) -> DaftResult> { let (probe_table, tables) = state.get_probeable_and_table(); @@ -170,7 +170,7 @@ impl HashJoinProbeOperator { let mut build_side_growable = GrowableTable::new( &tables.iter().collect::>(), true, - tables.iter().map(|t| t.len()).sum(), + tables.iter().map(daft_table::Table::len).sum(), )?; let input_tables = input.get_tables()?; @@ -233,33 +233,28 @@ impl IntermediateOperator for HashJoinProbeOperator { input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { - match idx { - 0 => { - let state = state - .expect("HashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let (probe_table, tables) = input.as_probe_table(); - state.set_table(probe_table, tables); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } - _ => { - let state = state - .expect("HashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let input = input.as_data(); - let out = match self.join_type { - JoinType::Inner => self.probe_inner(input, state), - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - _ => { - unimplemented!("Only Inner, Left, and Right joins are supported in HashJoinProbeOperator") - } - }?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } + let state = state + .expect("HashJoinProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("HashJoinProbeOperator state should be HashJoinProbeState"); + + if idx == 0 { + let (probe_table, tables) = input.as_probe_table(); + state.set_table(probe_table, tables); + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } else { + let input = input.as_data(); + let out = match self.join_type { + JoinType::Inner => self.probe_inner(input, state), + JoinType::Left | JoinType::Right => self.probe_left_right(input, state), + _ => { + unimplemented!( + "Only Inner, Left, and Right joins are supported in HashJoinProbeOperator" + ) + } + }?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index abb5c5388b..7b0267c7c0 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -36,7 +36,7 @@ pub trait IntermediateOperator: Send + Sync { } } -pub(crate) struct IntermediateNode { +pub struct IntermediateNode { intermediate_op: Arc, children: Vec>, runtime_stats: Arc, @@ -138,7 +138,7 @@ impl IntermediateNode { let mut buffer = OperatorBuffer::new(morsel_size); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { - for worker_sender in worker_senders.iter() { + for worker_sender in &worker_senders { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { @@ -166,13 +166,11 @@ impl TreeDisplay for IntermediateNode { use std::fmt::Write; let mut display = String::new(); writeln!(display, "{}", self.intermediate_op.name()).unwrap(); - use common_display::DisplayLevel::*; - match level { - Compact => {} - _ => { - let rt_result = self.runtime_stats.result(); - rt_result.display(&mut display, true, true, true).unwrap(); - } + use common_display::DisplayLevel::Compact; + if matches!(level, Compact) { + } else { + let rt_result = self.runtime_stats.result(); + rt_result.display(&mut display, 
true, true, true).unwrap(); } display } @@ -184,7 +182,10 @@ impl TreeDisplay for IntermediateNode { impl PipelineNode for IntermediateNode { fn children(&self) -> Vec<&dyn PipelineNode> { - self.children.iter().map(|v| v.as_ref()).collect() + self.children + .iter() + .map(std::convert::AsRef::as_ref) + .collect() } fn name(&self) -> &'static str { @@ -197,7 +198,7 @@ impl PipelineNode for IntermediateNode { runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let mut child_result_receivers = Vec::with_capacity(self.children.len()); - for child in self.children.iter_mut() { + for child in &mut self.children { let child_result_channel = child.start(maintain_order, runtime_handle)?; child_result_receivers .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 2968c2b04b..b7809b4126 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -20,6 +20,7 @@ pub struct ExecutionRuntimeHandle { } impl ExecutionRuntimeHandle { + #[must_use] pub fn new(default_morsel_size: usize) -> Self { Self { worker_set: tokio::task::JoinSet::new(), @@ -44,6 +45,7 @@ impl ExecutionRuntimeHandle { self.worker_set.shutdown().await; } + #[must_use] pub fn default_morsel_size(&self) -> usize { self.default_morsel_size } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 46e1b02628..53cb8e1215 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -84,7 +84,7 @@ pub trait PipelineNode: Sync + Send + TreeDisplay { fn as_tree_display(&self) -> &dyn TreeDisplay; } -pub(crate) fn viz_pipeline(root: &dyn PipelineNode) -> String { +pub fn viz_pipeline(root: &dyn PipelineNode) -> String { let mut output = String::new(); let mut visitor = MermaidDisplayVisitor::new( &mut output, @@ -158,7 +158,7 @@ pub fn physical_plan_to_pipeline( first_stage_aggs .values() .cloned() - .map(|e| Arc::new(Expr::Agg(e.clone()))) + .map(|e| Arc::new(Expr::Agg(e))) .collect(), vec![], ); @@ -170,7 +170,7 @@ pub fn physical_plan_to_pipeline( second_stage_aggs .values() .cloned() - .map(|e| Arc::new(Expr::Agg(e.clone()))) + .map(|e| Arc::new(Expr::Agg(e))) .collect(), vec![], ); @@ -196,7 +196,7 @@ pub fn physical_plan_to_pipeline( first_stage_aggs .values() .cloned() - .map(|e| Arc::new(Expr::Agg(e.clone()))) + .map(|e| Arc::new(Expr::Agg(e))) .collect(), group_by.clone(), ); @@ -212,7 +212,7 @@ pub fn physical_plan_to_pipeline( second_stage_aggs .values() .cloned() - .map(|e| Arc::new(Expr::Agg(e.clone()))) + .map(|e| Arc::new(Expr::Agg(e))) .collect(), group_by.clone(), ); @@ -265,7 +265,7 @@ pub fn physical_plan_to_pipeline( let probe_schema = probe_child.schema(); let probe_node = || -> DaftResult<_> { let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) - .map(|k| k.to_string()) + .map(std::string::ToString::to_string) .collect(); let build_key_fields = build_on .iter() @@ -300,8 +300,7 @@ pub fn physical_plan_to_pipeline( .collect::>(); // we should move to a builder pattern - let build_sink = - HashJoinBuildSink::new(key_schema.clone(), casted_build_on, join_type)?; + let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; let build_child_node = physical_plan_to_pipeline(build_child, psets)?; let build_node = BlockingSinkNode::new(build_sink.boxed(), build_child_node).boxed(); diff --git a/src/daft-local-execution/src/run.rs 
b/src/daft-local-execution/src/run.rs index 38d7c3e479..a89d06b2f8 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -82,7 +82,7 @@ impl NativeExecutor { part_id, parts .into_iter() - .map(|part| part.into()) + .map(std::convert::Into::into) .collect::>>(), ) }) @@ -130,7 +130,7 @@ pub fn run_local( .thread_name_fn(|| { static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); - format!("Executor-Worker-{}", id) + format!("Executor-Worker-{id}") }) .build() .expect("Failed to create tokio runtime"); @@ -159,7 +159,7 @@ pub fn run_local( .duration_since(UNIX_EPOCH) .expect("Time went backwards") .as_millis(); - let file_name = format!("explain-analyze-{}-mermaid.md", curr_ms); + let file_name = format!("explain-analyze-{curr_ms}-mermaid.md"); let mut file = File::create(file_name)?; writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; } @@ -187,7 +187,7 @@ pub fn run_local( .join() .expect("Execution engine thread panicked"); match join_result { - Ok(_) => None, + Ok(()) => None, Err(e) => Some(Err(e)), } } else { diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index 7489a8fd36..de1f657273 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -13,14 +13,14 @@ use crate::{ }; #[derive(Default)] -pub(crate) struct RuntimeStatsContext { +pub struct RuntimeStatsContext { rows_received: AtomicU64, rows_emitted: AtomicU64, cpu_us: AtomicU64, } #[derive(Debug)] -pub(crate) struct RuntimeStats { +pub struct RuntimeStats { pub rows_received: u64, pub rows_emitted: u64, pub cpu_us: u64, @@ -53,7 +53,7 @@ impl RuntimeStats { if cpu_time { let tms = (self.cpu_us as f32) / 1000f32; - writeln!(w, "CPU Time = {:.2}ms", tms)?; + writeln!(w, "CPU Time = {tms:.2}ms")?; } Ok(()) @@ -108,7 +108,7 @@ impl RuntimeStatsContext { } } -pub(crate) struct CountingSender { +pub struct CountingSender { sender: Sender, rt: Arc, } @@ -124,7 +124,9 @@ impl CountingSender { ) -> Result<(), SendError> { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => tables.iter().map(|t| t.len()).sum(), + PipelineResultType::ProbeTable(_, ref tables) => { + tables.iter().map(daft_table::Table::len).sum() + } }; self.sender.send(v).await?; self.rt.mark_rows_emitted(len as u64); @@ -132,7 +134,7 @@ impl CountingSender { } } -pub(crate) struct CountingReceiver { +pub struct CountingReceiver { receiver: PipelineReceiver, rt: Arc, } @@ -148,7 +150,7 @@ impl CountingReceiver { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(|t| t.len()).sum() + tables.iter().map(daft_table::Table::len).sum() } }; self.rt.mark_rows_received(len as u64); diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index eae85a3f21..e94ff7c68b 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -52,8 +52,12 @@ impl BlockingSink for AggregateSink { !parts.is_empty(), "We can not finalize AggregateSink with no data" ); - let concated = - MicroPartition::concat(&parts.iter().map(|x| x.as_ref()).collect::>())?; + let concated = MicroPartition::concat( + &parts + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(), + )?; let agged = 
Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); self.state = AggregateState::Done(agged.clone()); Ok(Some(agged.into())) diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 8894db503d..dc38e1df34 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -23,7 +23,7 @@ pub trait BlockingSink: Send + Sync { fn name(&self) -> &'static str; } -pub(crate) struct BlockingSinkNode { +pub struct BlockingSinkNode { // use a RW lock op: Arc>>, name: &'static str, @@ -51,13 +51,11 @@ impl TreeDisplay for BlockingSinkNode { use std::fmt::Write; let mut display = String::new(); writeln!(display, "{}", self.name()).unwrap(); - use common_display::DisplayLevel::*; - match level { - Compact => {} - _ => { - let rt_result = self.runtime_stats.result(); - rt_result.display(&mut display, true, true, true).unwrap(); - } + use common_display::DisplayLevel::Compact; + if matches!(level, Compact) { + } else { + let rt_result = self.runtime_stats.result(); + rt_result.display(&mut display, true, true, true).unwrap(); } display } @@ -96,9 +94,10 @@ impl PipelineNode for BlockingSinkNode { let span = info_span!("BlockingSinkNode::execute"); let mut guard = op.lock().await; while let Some(val) = child_results_receiver.recv().await { - if let BlockingSinkStatus::Finished = - rt_context.in_span(&span, || guard.sink(val.as_data()))? - { + if matches!( + rt_context.in_span(&span, || guard.sink(val.as_data()))?, + BlockingSinkStatus::Finished + ) { break; } } diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 5f84045101..3af65702cd 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -76,7 +76,7 @@ impl ProbeTableState { } } -pub(crate) struct HashJoinBuildSink { +pub struct HashJoinBuildSink { probe_table_state: ProbeTableState, } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 91435961d5..40b4d1538f 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -35,7 +35,7 @@ impl StreamingSink for LimitSink { let input_num_rows = input.len(); - use std::cmp::Ordering::*; + use std::cmp::Ordering::{Equal, Greater, Less}; match input_num_rows.cmp(&self.remaining) { Less => { self.remaining -= input_num_rows; diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 86d951fd83..169ea9e55d 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -50,8 +50,12 @@ impl BlockingSink for SortSink { !parts.is_empty(), "We can not finalize SortSink with no data" ); - let concated = - MicroPartition::concat(&parts.iter().map(|x| x.as_ref()).collect::>())?; + let concated = MicroPartition::concat( + &parts + .iter() + .map(std::convert::AsRef::as_ref) + .collect::>(), + )?; let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); self.state = SortState::Done(sorted.clone()); Ok(Some(sorted.into())) diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 5b188c4ad8..f18a7efca0 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -27,7 +27,7 @@ pub trait StreamingSink: 
Send + Sync { fn name(&self) -> &'static str; } -pub(crate) struct StreamingSinkNode { +pub struct StreamingSinkNode { // use a RW lock op: Arc>>, name: &'static str, @@ -55,13 +55,11 @@ impl TreeDisplay for StreamingSinkNode { use std::fmt::Write; let mut display = String::new(); writeln!(display, "{}", self.name()).unwrap(); - use common_display::DisplayLevel::*; - match level { - Compact => {} - _ => { - let rt_result = self.runtime_stats.result(); - rt_result.display(&mut display, true, true, true).unwrap(); - } + use common_display::DisplayLevel::Compact; + if matches!(level, Compact) { + } else { + let rt_result = self.runtime_stats.result(); + rt_result.display(&mut display, true, true, true).unwrap(); } display } @@ -75,7 +73,10 @@ impl TreeDisplay for StreamingSinkNode { impl PipelineNode for StreamingSinkNode { fn children(&self) -> Vec<&dyn PipelineNode> { - self.children.iter().map(|v| v.as_ref()).collect() + self.children + .iter() + .map(std::convert::AsRef::as_ref) + .collect() } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 58ca336b94..55df4a66e1 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -65,19 +65,18 @@ impl Source for ScanTaskSource { runtime_handle: &mut ExecutionRuntimeHandle, io_stats: IOStatsRef, ) -> crate::Result> { - let (senders, receivers): (Vec<_>, Vec<_>) = match maintain_order { - true => (0..self.scan_tasks.len()) + let (senders, receivers): (Vec<_>, Vec<_>) = if maintain_order { + (0..self.scan_tasks.len()) .map(|_| create_channel(1)) - .unzip(), - false => { - let (sender, receiver) = create_channel(self.scan_tasks.len()); - ( - std::iter::repeat(sender) - .take(self.scan_tasks.len()) - .collect(), - vec![receiver], - ) - } + .unzip() + } else { + let (sender, receiver) = create_channel(self.scan_tasks.len()); + ( + std::iter::repeat(sender) + .take(self.scan_tasks.len()) + .collect(), + vec![receiver], + ) }; for (scan_task, sender) in self.scan_tasks.iter().zip(senders) { runtime_handle.spawn( @@ -104,18 +103,18 @@ async fn stream_scan_task( io_stats: Option, maintain_order: bool, ) -> DaftResult>> + Send> { - let pushdown_columns = scan_task - .pushdowns - .columns - .as_ref() - .map(|v| v.iter().map(|s| s.as_str()).collect::>()); + let pushdown_columns = scan_task.pushdowns.columns.as_ref().map(|v| { + v.iter() + .map(std::string::String::as_str) + .collect::>() + }); let file_column_names = match ( pushdown_columns, scan_task.partition_spec().map(|ps| ps.to_fill_map()), ) { (None, _) => None, - (Some(columns), None) => Some(columns.to_vec()), + (Some(columns), None) => Some(columns.clone()), // If the ScanTask has a partition_spec, we elide reads of partition columns from the file (Some(columns), Some(partition_fillmap)) => Some( @@ -209,10 +208,10 @@ async fn stream_scan_task( scan_task.pushdowns.limit, file_column_names .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), col_names .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), Some(schema_of_file), scan_task.pushdowns.filters.clone(), ); @@ -244,7 +243,7 @@ async fn stream_scan_task( scan_task.pushdowns.limit, file_column_names .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + .map(|cols| 
cols.iter().map(|col| (*col).to_string()).collect()), Some(schema_of_file), scan_task.pushdowns.filters.clone(), ); @@ -311,7 +310,7 @@ async fn stream_scan_task( .as_ref(), )?; let mp = Arc::new(MicroPartition::new_loaded( - scan_task.materialized_schema().clone(), + scan_task.materialized_schema(), Arc::new(vec![casted_table]), scan_task.statistics.clone(), )); diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 175dc66427..8c55401db2 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -12,7 +12,7 @@ use crate::{ pub type SourceStream<'a> = BoxStream<'a, Arc>; -pub(crate) trait Source: Send + Sync { +pub trait Source: Send + Sync { fn name(&self) -> &'static str; fn get_data( &self, @@ -33,22 +33,20 @@ impl TreeDisplay for SourceNode { use std::fmt::Write; let mut display = String::new(); writeln!(display, "{}", self.name()).unwrap(); - use common_display::DisplayLevel::*; - match level { - Compact => {} - _ => { - let rt_result = self.runtime_stats.result(); + use common_display::DisplayLevel::Compact; + if matches!(level, Compact) { + } else { + let rt_result = self.runtime_stats.result(); - writeln!(display).unwrap(); - rt_result.display(&mut display, false, true, false).unwrap(); - let bytes_read = self.io_stats.load_bytes_read(); - writeln!( - display, - "bytes read = {}", - bytes_to_human_readable(bytes_read) - ) - .unwrap(); - } + writeln!(display).unwrap(); + rt_result.display(&mut display, false, true, false).unwrap(); + let bytes_read = self.io_stats.load_bytes_read(); + writeln!( + display, + "bytes read = {}", + bytes_to_human_readable(bytes_read) + ) + .unwrap(); } display } diff --git a/src/daft-micropartition/src/lib.rs b/src/daft-micropartition/src/lib.rs index 1a01f4e933..c677a0fd96 100644 --- a/src/daft-micropartition/src/lib.rs +++ b/src/daft-micropartition/src/lib.rs @@ -1,5 +1,6 @@ #![feature(let_chains)] #![feature(iterator_try_reduce)] +#![feature(iterator_try_collect)] use common_error::DaftError; use snafu::Snafu; diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 5149722a1e..105fd561e8 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -1,7 +1,6 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, fmt::Display, - ops::Deref, sync::{Arc, Mutex}, }; @@ -30,7 +29,7 @@ use {crate::PyIOSnafu, common_file_formats::DatabaseSourceConfig}; use crate::{DaftCSVSnafu, DaftCoreComputeSnafu}; #[derive(Debug)] -pub(crate) enum TableState { +pub enum TableState { Unloaded(Arc), Loaded(Arc>), } @@ -45,14 +44,14 @@ impl Display for TableState { scan_task .sources .iter() - .map(|s| s.get_path()) + .map(daft_scan::DataSource::get_path) .collect::>() ) } Self::Loaded(tables) => { writeln!(f, "TableState: Loaded. {} tables", tables.len())?; for tab in tables.iter() { - writeln!(f, "{}", tab)?; + writeln!(f, "{tab}")?; } Ok(()) } @@ -97,26 +96,23 @@ fn materialize_scan_task( scan_task: Arc, io_stats: Option, ) -> crate::Result<(Vec
, SchemaRef)> { - let pushdown_columns = scan_task - .pushdowns - .columns - .as_ref() - .map(|v| v.iter().map(|s| s.as_str()).collect::>()); + let pushdown_columns = scan_task.pushdowns.columns.as_ref().map(|v| { + v.iter() + .map(std::string::String::as_str) + .collect::>() + }); let file_column_names = _get_file_column_names(pushdown_columns.as_deref(), scan_task.partition_spec()); - let urls = scan_task.sources.iter().map(|s| s.get_path()); + let urls = scan_task + .sources + .iter() + .map(daft_scan::DataSource::get_path); let mut table_values = match scan_task.storage_config.as_ref() { StorageConfig::Native(native_storage_config) => { let multithreaded_io = native_storage_config.multithreaded_io; - let io_config = Arc::new( - native_storage_config - .io_config - .as_ref() - .cloned() - .unwrap_or_default(), - ); + let io_config = Arc::new(native_storage_config.io_config.clone().unwrap_or_default()); let io_client = daft_io::get_io_client(multithreaded_io, io_config).unwrap(); match scan_task.file_format_config.as_ref() { @@ -141,7 +137,7 @@ fn materialize_scan_task( let iceberg_delete_files = scan_task .sources .iter() - .flat_map(|s| s.get_iceberg_delete_files()) + .filter_map(|s| s.get_iceberg_delete_files()) .flatten() .map(String::as_str) .collect::>() @@ -172,7 +168,7 @@ fn materialize_scan_task( scan_task.pushdowns.limit, row_groups, scan_task.pushdowns.filters.clone(), - io_client.clone(), + io_client, io_stats, num_parallel_tasks, multithreaded_io, @@ -205,10 +201,10 @@ fn materialize_scan_task( scan_task.pushdowns.limit, file_column_names .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), col_names .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), Some(schema_of_file), scan_task.pushdowns.filters.clone(), ); @@ -247,7 +243,7 @@ fn materialize_scan_task( scan_task.pushdowns.limit, file_column_names .as_ref() - .map(|cols| cols.iter().map(|col| col.to_string()).collect()), + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), Some(scan_task.schema.clone()), scan_task.pushdowns.filters.clone(), ); @@ -306,7 +302,7 @@ fn materialize_scan_task( .map(|cols| cols.as_ref().clone()), scan_task.pushdowns.limit, ) - .map(|t| t.into()) + .map(std::convert::Into::into) .context(PyIOSnafu) }) .collect::>>() @@ -333,7 +329,7 @@ fn materialize_scan_task( .map(|cols| cols.as_ref().clone()), scan_task.pushdowns.limit, ) - .map(|t| t.into()) + .map(std::convert::Into::into) .context(PyIOSnafu) }) .collect::>>() @@ -352,7 +348,7 @@ fn materialize_scan_task( .map(|cols| cols.as_ref().clone()), scan_task.pushdowns.limit, ) - .map(|t| t.into()) + .map(std::convert::Into::into) .context(PyIOSnafu) }) .collect::>>() @@ -377,7 +373,7 @@ fn materialize_scan_task( .map(|cols| cols.as_ref().clone()), scan_task.pushdowns.limit, ) - .map(|t| t.into()) + .map(std::convert::Into::into) .context(PyIOSnafu)?; Ok(vec![table]) })? @@ -411,14 +407,13 @@ impl MicroPartition { /// Invariants: /// 1. Each Loaded column statistic in `statistics` must be castable to the corresponding column in the MicroPartition's schema /// 2. 
Creating a new MicroPartition with a ScanTask that has any filter predicates or limits is not allowed and will panic + #[must_use] pub fn new_unloaded( scan_task: Arc<ScanTask>, metadata: TableMetadata, statistics: TableStatistics, ) -> Self { - if scan_task.pushdowns.filters.is_some() { - panic!("Cannot create unloaded MicroPartition from a ScanTask with pushdowns that have filters"); - } + assert!(scan_task.pushdowns.filters.is_none(), "Cannot create unloaded MicroPartition from a ScanTask with pushdowns that have filters"); let schema = scan_task.materialized_schema(); let fill_map = scan_task.partition_spec().map(|pspec| pspec.to_fill_map()); @@ -438,6 +433,7 @@ impl MicroPartition { /// Schema invariants: /// 1. `schema` must match each Table's schema exactly /// 2. If `statistics` is provided, each Loaded column statistic must be castable to the corresponding column in the MicroPartition's schema + #[must_use] pub fn new_loaded( schema: SchemaRef, tables: Arc<Vec<Table>>, @@ -456,7 +452,7 @@ .cast_to_schema(schema.clone()) .expect("Statistics cannot be casted to schema") }); - let tables_len_sum = tables.iter().map(|t| t.len()).sum(); + let tables_len_sum = tables.iter().map(daft_table::Table::len).sum(); Self { schema, @@ -508,13 +504,13 @@ let uris = scan_task .sources .iter() - .map(|s| s.get_path()) + .map(daft_scan::DataSource::get_path) .collect::<Vec<_>>(); - let columns = scan_task - .pushdowns - .columns - .as_ref() - .map(|cols| cols.iter().map(|s| s.as_str()).collect::<Vec<&str>>()); + let columns = scan_task.pushdowns.columns.as_ref().map(|cols| { cols.iter() .map(std::string::String::as_str) .collect::<Vec<&str>>() }); let parquet_metadata = scan_task .sources .iter() @@ -524,7 +520,7 @@ let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice()); let mut iceberg_delete_files: HashSet<&str> = HashSet::new(); - for source in scan_task.sources.iter() { + for source in &scan_task.sources { if let Some(delete_files) = source.get_iceberg_delete_files() { iceberg_delete_files.extend(delete_files.iter().map(String::as_str)); } @@ -539,10 +535,7 @@ row_groups, scan_task.pushdowns.filters.clone(), scan_task.partition_spec(), - cfg.io_config - .clone() - .map(|c| Arc::new(c.clone())) - .unwrap_or_default(), + cfg.io_config.clone().map(Arc::new).unwrap_or_default(), Some(io_stats), if scan_task.sources.len() == 1 { 1 } else { 128 }, // Hardcoded to 128 bulk reads cfg.multithreaded_io, @@ -550,7 +543,7 @@ coerce_int96_timestamp_unit, ..Default::default() }, - Some(schema.clone()), + Some(schema), field_id_mapping.clone(), parquet_metadata, chunk_size, @@ -568,8 +561,9 @@ } } + #[must_use] pub fn empty(schema: Option<SchemaRef>) -> Self { - let schema = schema.unwrap_or(Schema::empty().into()); + let schema = schema.unwrap_or_else(|| Schema::empty().into()); Self::new_loaded(schema, Arc::new(vec![]), None) } @@ -591,15 +585,15 @@ pub fn size_bytes(&self) -> DaftResult<Option<usize>> { let guard = self.state.lock().unwrap(); - let size_bytes = if let TableState::Loaded(tables) = guard.deref() { + let size_bytes = if let TableState::Loaded(tables) = &*guard { let total_size: usize = tables .iter() - .map(|t| t.size_bytes()) + .map(daft_table::Table::size_bytes) .collect::<DaftResult<Vec<_>>>()?
.iter() .sum(); Some(total_size) - } else if let TableState::Unloaded(scan_task) = guard.deref() { + } else if let TableState::Unloaded(scan_task) = &*guard { // TODO: pass in the execution config once we have it available scan_task.estimate_in_memory_size_bytes(None) } else { @@ -610,9 +604,17 @@ impl MicroPartition { Ok(size_bytes) } + /// Retrieves the tables of the MicroPartition, reading the data if it is not already loaded. + /// + /// This method: + /// 1. Returns the cached tables if they are already loaded. + /// 2. Otherwise reads the data from the source, caches the resulting tables, and returns them. + /// + /// I/O therefore only happens on the first access to unloaded data; + /// repeated calls return the cached tables without re-reading the source. pub(crate) fn tables_or_read(&self, io_stats: IOStatsRef) -> crate::Result<Arc<Vec<Table>>> { let mut guard = self.state.lock().unwrap(); - match guard.deref() { + match &*guard { TableState::Unloaded(scan_task) => { let (tables, _) = materialize_scan_task(scan_task.clone(), Some(io_stats))?; let table_values = Arc::new(tables); @@ -642,7 +644,7 @@ impl MicroPartition { .context(DaftCoreComputeSnafu)?; *guard = TableState::Loaded(Arc::new(vec![new_table])); }; - if let TableState::Loaded(tables) = guard.deref() { + if let TableState::Loaded(tables) = &*guard { assert_eq!(tables.len(), 1); Ok(tables.clone()) } else { @@ -693,7 +695,7 @@ fn prune_fields_from_schema( let avail_names = schema .fields .keys() - .map(|f| f.as_str()) + .map(std::string::String::as_str) .collect::<HashSet<_>>(); let mut names_to_keep = HashSet::new(); for col_name in columns { @@ -701,8 +703,8 @@ names_to_keep.insert(*col_name); } else { return Err(super::Error::FieldNotFound { - field: col_name.to_string(), - available_fields: avail_names.iter().map(|v| v.to_string()).collect(), + field: (*col_name).to_string(), + available_fields: avail_names.iter().map(|v| (*v).to_string()).collect(), } .into()); } @@ -731,14 +733,14 @@ fn parquet_sources_to_row_groups(sources: &[DataSource]) -> Option>(); - if row_groups.iter().any(|rgs| rgs.is_some()) { + if row_groups.iter().any(std::option::Option::is_some) { Some(row_groups) } else { None } } -pub(crate) fn read_csv_into_micropartition( +pub fn read_csv_into_micropartition( uris: &[&str], convert_options: Option<CsvConvertOptions>, parse_options: Option<CsvParseOptions>, @@ -747,7 +749,7 @@ multithreaded_io: bool, io_stats: Option<IOStatsRef>, ) -> DaftResult<MicroPartition> { - let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?; + let io_client = daft_io::get_io_client(multithreaded_io, io_config)?; match uris { [] => Ok(MicroPartition::empty(None)), @@ -779,7 +781,7 @@ // Construct MicroPartition from tables and unioned schema Ok(MicroPartition::new_loaded( - unioned_schema.clone(), + unioned_schema, Arc::new(tables), None, )) @@ -787,7 +789,7 @@ } } -pub(crate) fn read_json_into_micropartition( +pub fn read_json_into_micropartition( uris: &[&str], convert_options: Option<JsonConvertOptions>, parse_options: Option<JsonParseOptions>, @@ -796,7 +798,7 @@ multithreaded_io: bool, io_stats: Option<IOStatsRef>, ) -> DaftResult<MicroPartition> { - let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?; + let io_client = daft_io::get_io_client(multithreaded_io, io_config)?; match uris { [] => Ok(MicroPartition::empty(None)), @@ -828,7 +830,7 @@ // Construct MicroPartition from tables and unioned schema
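// (Clarifying note, with an assumption not stated in this hunk: `new_loaded`
// requires `schema` to match every table's schema exactly, per its documented
// invariants above, so the tables collected here are presumed to have already
// been cast to `unioned_schema` by the preceding per-file read step.)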
Ok(MicroPartition::new_loaded( - unioned_schema.clone(), + unioned_schema, Arc::new(tables), None, )) @@ -890,10 +892,12 @@ fn _read_delete_files( None, )?; - let mut delete_map: HashMap> = - uris.iter().map(|uri| (uri.to_string(), vec![])).collect(); + let mut delete_map: HashMap> = uris + .iter() + .map(|uri| ((*uri).to_string(), vec![])) + .collect(); - for table in tables.iter() { + for table in &tables { // values in the file_path column are guaranteed by the iceberg spec to match the full URI of the corresponding data file // https://iceberg.apache.org/spec/#position-delete-files let file_paths = table.get_column("file_path")?.downcast::()?; @@ -948,7 +952,11 @@ fn _read_parquet_into_loaded_micropartition>( }) .transpose()?; - let columns = columns.map(|cols| cols.iter().map(|c| c.as_ref()).collect::>()); + let columns = columns.map(|cols| { + cols.iter() + .map(std::convert::AsRef::as_ref) + .collect::>() + }); let file_column_names = _get_file_column_names(columns.as_deref(), partition_spec); let all_tables = read_parquet_bulk( @@ -970,15 +978,14 @@ fn _read_parquet_into_loaded_micropartition>( )?; // Prefer using the `catalog_provided_schema` but fall back onto inferred schema from Parquet files - let full_daft_schema = match catalog_provided_schema { - Some(catalog_provided_schema) => catalog_provided_schema, - None => { - let unioned_schema = all_tables - .iter() - .map(|t| t.schema.clone()) - .try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?; - unioned_schema.expect("we need at least 1 schema") - } + let full_daft_schema = if let Some(catalog_provided_schema) = catalog_provided_schema { + catalog_provided_schema + } else { + let unioned_schema = all_tables + .iter() + .map(|t| t.schema.clone()) + .try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?; + unioned_schema.expect("we need at least 1 schema") }; let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns.as_deref())?; @@ -999,7 +1006,7 @@ fn _read_parquet_into_loaded_micropartition>( } #[allow(clippy::too_many_arguments)] -pub(crate) fn read_parquet_into_micropartition>( +pub fn read_parquet_into_micropartition>( uris: &[&str], columns: Option<&[T]>, start_offset: Option, @@ -1120,14 +1127,13 @@ pub(crate) fn read_parquet_into_micropartition>( // by constructing an appropriate ScanTask if let Some(stats) = stats { // Prefer using the `catalog_provided_schema` but fall back onto inferred schema from Parquet files - let scan_task_daft_schema = match catalog_provided_schema { - Some(catalog_provided_schema) => catalog_provided_schema, - None => { - let unioned_schema = schemas - .into_iter() - .try_reduce(|l, r| l.union(&r).map(Arc::new))?; - unioned_schema.expect("we need at least 1 schema") - } + let scan_task_daft_schema = if let Some(catalog_provided_schema) = catalog_provided_schema { + catalog_provided_schema + } else { + let unioned_schema = schemas + .into_iter() + .try_reduce(|l, r| l.union(&r).map(Arc::new))?; + unioned_schema.expect("we need at least 1 schema") }; // Get total number of rows, accounting for selected `row_groups` and the indicated `num_rows` @@ -1145,11 +1151,11 @@ pub(crate) fn read_parquet_into_micropartition>( }) .sum(), }; - let total_rows = num_rows - .map(|num_rows| num_rows.min(total_rows_no_limit)) - .unwrap_or(total_rows_no_limit); + let total_rows = num_rows.map_or(total_rows_no_limit, |num_rows| { + num_rows.min(total_rows_no_limit) + }); - let owned_urls = uris.iter().map(|s| s.to_string()).collect::>(); + let owned_urls = uris.iter().map(|s| 
(*s).to_string()); let size_bytes = metadata .iter() .map(|m| -> u64 { @@ -1245,7 +1251,7 @@ impl Display for MicroPartition { writeln!(f, "MicroPartition with {} rows:", self.len())?; - match guard.deref() { + match &*guard { TableState::Unloaded(..) => { writeln!(f, "{}\n{}", self.schema, guard)?; } @@ -1253,12 +1259,12 @@ impl Display for MicroPartition { if tables.len() == 0 { writeln!(f, "{}", self.schema)?; } - writeln!(f, "{}", guard)?; + writeln!(f, "{guard}")?; } }; match &self.statistics { - Some(t) => writeln!(f, "Statistics\n{}", t)?, + Some(t) => writeln!(f, "Statistics\n{t}")?, None => writeln!(f, "Statistics: missing")?, } diff --git a/src/daft-micropartition/src/ops/cast_to_schema.rs b/src/daft-micropartition/src/ops/cast_to_schema.rs index 1612a83eae..96b4b1b9af 100644 --- a/src/daft-micropartition/src/ops/cast_to_schema.rs +++ b/src/daft-micropartition/src/ops/cast_to_schema.rs @@ -1,4 +1,4 @@ -use std::{ops::Deref, sync::Arc}; +use std::sync::Arc; use common_error::DaftResult; use daft_core::prelude::SchemaRef; @@ -16,7 +16,7 @@ impl MicroPartition { .transpose()?; let guard = self.state.lock().unwrap(); - match guard.deref() { + match &*guard { // Replace schema if Unloaded, which should be applied when data is lazily loaded TableState::Unloaded(scan_task) => { let maybe_new_scan_task = if scan_task.schema == schema { diff --git a/src/daft-micropartition/src/ops/concat.rs b/src/daft-micropartition/src/ops/concat.rs index 682f75f4ce..2108cc01e3 100644 --- a/src/daft-micropartition/src/ops/concat.rs +++ b/src/daft-micropartition/src/ops/concat.rs @@ -30,7 +30,7 @@ impl MicroPartition { let mut all_tables = vec![]; - for m in mps.iter() { + for m in mps { let tables = m.tables_or_read(io_stats.clone())?; all_tables.extend_from_slice(tables.as_slice()); } @@ -45,7 +45,7 @@ impl MicroPartition { all_stats = Some(curr_stats.union(stats)?); } } - let new_len = all_tables.iter().map(|t| t.len()).sum(); + let new_len = all_tables.iter().map(daft_table::Table::len).sum(); Ok(Self { schema: mps.first().unwrap().schema.clone(), diff --git a/src/daft-micropartition/src/ops/eval_expressions.rs b/src/daft-micropartition/src/ops/eval_expressions.rs index 8ac5966a2e..14baff4c37 100644 --- a/src/daft-micropartition/src/ops/eval_expressions.rs +++ b/src/daft-micropartition/src/ops/eval_expressions.rs @@ -16,7 +16,7 @@ fn infer_schema(exprs: &[ExprRef], schema: &Schema) -> DaftResult { .collect::>>()?; let mut seen: HashSet = HashSet::new(); - for field in fields.iter() { + for field in &fields { let name = &field.name; if seen.contains(name) { return Err(DaftError::ValueError(format!( @@ -33,16 +33,18 @@ impl MicroPartition { let io_stats = IOStatsContext::new("MicroPartition::eval_expression_list"); let expected_schema = infer_schema(exprs, &self.schema)?; + let tables = self.tables_or_read(io_stats)?; - let evaluated_tables = tables + + let evaluated_tables: Vec<_> = tables .iter() - .map(|t| t.eval_expression_list(exprs)) - .collect::>>()?; + .map(|table| table.eval_expression_list(exprs)) + .try_collect()?; let eval_stats = self .statistics .as_ref() - .map(|s| s.eval_expression_list(exprs, &expected_schema)) + .map(|table_statistics| table_statistics.eval_expression_list(exprs, &expected_schema)) .transpose()?; Ok(Self::new_loaded( @@ -63,7 +65,7 @@ impl MicroPartition { let expected_new_columns = infer_schema(exprs, &self.schema)?; let eval_stats = if let Some(stats) = &self.statistics { let mut new_stats = stats.columns.clone(); - for (name, _) in expected_new_columns.fields.iter() 
{ + for (name, _) in &expected_new_columns.fields { if let Some(v) = new_stats.get_mut(name) { *v = ColumnRangeStatistics::Missing; } else { @@ -77,7 +79,7 @@ impl MicroPartition { let mut expected_schema = Schema::new(self.schema.fields.values().cloned().collect::>())?; - for (name, field) in expected_new_columns.fields.into_iter() { + for (name, field) in expected_new_columns.fields { if let Some(v) = expected_schema.fields.get_mut(&name) { *v = field; } else { diff --git a/src/daft-micropartition/src/ops/filter.rs b/src/daft-micropartition/src/ops/filter.rs index a097192f30..e555c267af 100644 --- a/src/daft-micropartition/src/ops/filter.rs +++ b/src/daft-micropartition/src/ops/filter.rs @@ -16,7 +16,7 @@ impl MicroPartition { let folded_expr = predicate .iter() .cloned() - .reduce(|a, b| a.and(b)) + .reduce(daft_dsl::Expr::and) .expect("should have at least 1 expr"); let eval_result = statistics.eval_expression(&folded_expr)?; let tv = eval_result.to_truth_value(); diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index bac67f12db..0d671d0fe3 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -24,12 +24,9 @@ impl MicroPartition { { let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on, how)?; match (how, self.len(), right.len()) { - (JoinType::Inner, 0, _) - | (JoinType::Inner, _, 0) - | (JoinType::Left, 0, _) - | (JoinType::Right, _, 0) - | (JoinType::Outer, 0, 0) - | (JoinType::Semi, 0, _) => { + (JoinType::Inner | JoinType::Left | JoinType::Semi, 0, _) + | (JoinType::Inner | JoinType::Right, _, 0) + | (JoinType::Outer, 0, 0) => { return Ok(Self::empty(Some(join_schema))); } _ => {} @@ -49,7 +46,7 @@ impl MicroPartition { .values() .zip(r_eval_stats.columns.values()) { - if let TruthValue::False = lc.equal(rc)?.to_truth_value() { + if lc.equal(rc)?.to_truth_value() == TruthValue::False { curr_tv = TruthValue::False; break; } @@ -57,7 +54,7 @@ impl MicroPartition { curr_tv } }; - if let TruthValue::False = tv { + if tv == TruthValue::False { return Ok(Self::empty(Some(join_schema))); } } diff --git a/src/daft-micropartition/src/ops/partition.rs b/src/daft-micropartition/src/ops/partition.rs index 8ca24e276f..14f20bd11d 100644 --- a/src/daft-micropartition/src/ops/partition.rs +++ b/src/daft-micropartition/src/ops/partition.rs @@ -12,7 +12,10 @@ fn transpose2(v: Vec>) -> Vec> { return v; } let len = v[0].len(); - let mut iters: Vec<_> = v.into_iter().map(|n| n.into_iter()).collect(); + let mut iters: Vec<_> = v + .into_iter() + .map(std::iter::IntoIterator::into_iter) + .collect(); (0..len) .map(|_| { iters diff --git a/src/daft-micropartition/src/ops/pivot.rs b/src/daft-micropartition/src/ops/pivot.rs index 3a4ad964b9..15ff085382 100644 --- a/src/daft-micropartition/src/ops/pivot.rs +++ b/src/daft-micropartition/src/ops/pivot.rs @@ -21,7 +21,7 @@ impl MicroPartition { [] => { let empty_table = Table::empty(Some(self.schema.clone()))?; let pivoted = empty_table.pivot(group_by, pivot_col, values_col, names)?; - Ok(Self::empty(Some(pivoted.schema.clone()))) + Ok(Self::empty(Some(pivoted.schema))) } [t] => { let pivoted = t.pivot(group_by, pivot_col, values_col, names)?; diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 72f9dab86c..2a4ac6eb18 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -1,7 +1,4 @@ -use std::{ - ops::Deref, - sync::{Arc, Mutex}, -}; +use 
std::sync::{Arc, Mutex}; use common_error::DaftResult; use daft_core::{ @@ -162,7 +159,8 @@ impl PyMicroPartition { } pub fn eval_expression_list(&self, py: Python, exprs: Vec) -> PyResult { - let converted_exprs: Vec = exprs.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = + exprs.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .inner @@ -176,7 +174,8 @@ impl PyMicroPartition { } pub fn filter(&self, py: Python, exprs: Vec) -> PyResult { - let converted_exprs: Vec = exprs.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = + exprs.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| Ok(self.inner.filter(converted_exprs.as_slice())?.into())) } @@ -186,8 +185,10 @@ impl PyMicroPartition { sort_keys: Vec, descending: Vec, ) -> PyResult { - let converted_exprs: Vec = - sort_keys.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = sort_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { Ok(self .inner @@ -202,8 +203,10 @@ impl PyMicroPartition { sort_keys: Vec, descending: Vec, ) -> PyResult { - let converted_exprs: Vec = - sort_keys.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = sort_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { Ok(self .inner @@ -214,9 +217,9 @@ impl PyMicroPartition { pub fn agg(&self, py: Python, to_agg: Vec, group_by: Vec) -> PyResult { let converted_to_agg: Vec = - to_agg.into_iter().map(|e| e.into()).collect(); + to_agg.into_iter().map(std::convert::Into::into).collect(); let converted_group_by: Vec = - group_by.into_iter().map(|e| e.into()).collect(); + group_by.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .inner @@ -234,7 +237,7 @@ impl PyMicroPartition { names: Vec, ) -> PyResult { let converted_group_by: Vec = - group_by.into_iter().map(|e| e.into()).collect(); + group_by.into_iter().map(std::convert::Into::into).collect(); let converted_pivot_col: daft_dsl::ExprRef = pivot_col.into(); let converted_values_col: daft_dsl::ExprRef = values_col.into(); py.allow_threads(|| { @@ -258,8 +261,10 @@ impl PyMicroPartition { right_on: Vec, how: JoinType, ) -> PyResult { - let left_exprs: Vec = left_on.into_iter().map(|e| e.into()).collect(); - let right_exprs: Vec = right_on.into_iter().map(|e| e.into()).collect(); + let left_exprs: Vec = + left_on.into_iter().map(std::convert::Into::into).collect(); + let right_exprs: Vec = + right_on.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .inner @@ -281,8 +286,10 @@ impl PyMicroPartition { right_on: Vec, is_sorted: bool, ) -> PyResult { - let left_exprs: Vec = left_on.into_iter().map(|e| e.into()).collect(); - let right_exprs: Vec = right_on.into_iter().map(|e| e.into()).collect(); + let left_exprs: Vec = + left_on.into_iter().map(std::convert::Into::into).collect(); + let right_exprs: Vec = + right_on.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .inner @@ -311,9 +318,10 @@ impl PyMicroPartition { variable_name: &str, value_name: &str, ) -> PyResult { - let converted_ids: Vec = ids.into_iter().map(|e| e.into()).collect(); + let converted_ids: Vec = + ids.into_iter().map(std::convert::Into::into).collect(); let converted_values: Vec = - values.into_iter().map(|e| e.into()).collect(); + values.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .inner @@ -405,13 +413,14 @@ impl 
PyMicroPartition { "Can not partition into negative number of partitions: {num_partitions}" ))); } - let exprs: Vec = exprs.into_iter().map(|e| e.into()).collect(); + let exprs: Vec = + exprs.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .inner .partition_by_hash(exprs.as_slice(), num_partitions as usize)? .into_iter() - .map(|t| t.into()) + .map(std::convert::Into::into) .collect::>()) }) } @@ -438,7 +447,7 @@ impl PyMicroPartition { .inner .partition_by_random(num_partitions as usize, seed as u64)? .into_iter() - .map(|t| t.into()) + .map(std::convert::Into::into) .collect::>()) }) } @@ -450,13 +459,16 @@ impl PyMicroPartition { boundaries: &PyTable, descending: Vec, ) -> PyResult> { - let exprs: Vec = partition_keys.into_iter().map(|e| e.into()).collect(); + let exprs: Vec = partition_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { Ok(self .inner .partition_by_range(exprs.as_slice(), &boundaries.table, descending.as_slice())? .into_iter() - .map(|t| t.into()) + .map(std::convert::Into::into) .collect::>()) }) } @@ -466,10 +478,16 @@ impl PyMicroPartition { py: Python, partition_keys: Vec, ) -> PyResult<(Vec, Self)> { - let exprs: Vec = partition_keys.into_iter().map(|e| e.into()).collect(); + let exprs: Vec = partition_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { let (mps, values) = self.inner.partition_by_value(exprs.as_slice())?; - let mps = mps.into_iter().map(|m| m.into()).collect::>(); + let mps = mps + .into_iter() + .map(std::convert::Into::into) + .collect::>(); let values = values.into(); Ok((mps, values)) }) @@ -719,7 +737,7 @@ impl PyMicroPartition { PyBytes::new_bound(py, &bincode::serialize(&self.inner.statistics).unwrap()); let guard = self.inner.state.lock().unwrap(); - if let TableState::Loaded(tables) = guard.deref() { + if let TableState::Loaded(tables) = &*guard { let _from_pytable = py .import_bound(pyo3::intern!(py, "daft.table"))? .getattr(pyo3::intern!(py, "Table"))? @@ -735,7 +753,7 @@ impl PyMicroPartition { .into(), (schema_bytes, pyobjs, py_metadata_bytes, py_stats_bytes).to_object(py), )) - } else if let TableState::Unloaded(params) = guard.deref() { + } else if let TableState::Unloaded(params) = &*guard { let py_params_bytes = PyBytes::new_bound(py, &bincode::serialize(params).unwrap()); Ok(( Self::type_object_bound(py) @@ -918,7 +936,7 @@ pub fn read_pyfunc_into_table_iter( let scan_task_filters = scan_task.pushdowns.filters.clone(); let res = table_iterators .into_iter() - .flat_map(|iter| { + .filter_map(|iter| { Python::with_gil(|py| { iter.downcast_bound::(py) .expect("Function must return an iterator of tables") diff --git a/src/daft-minhash/src/minhash.rs b/src/daft-minhash/src/minhash.rs index 3a7e666de8..3228d09451 100644 --- a/src/daft-minhash/src/minhash.rs +++ b/src/daft-minhash/src/minhash.rs @@ -11,7 +11,7 @@ const SIMD_LANES: usize = 8; type S = Simd; const MERSENNE_EXP: u64 = 61; -const MAX_HASH: u64 = 0xffffffff; +const MAX_HASH: u64 = 0xffff_ffff; const MAX_HASH_SIMD: S = S::from_array([MAX_HASH; SIMD_LANES]); // Fails with probability <= 2^-58, which is good enough for hashing @@ -43,7 +43,7 @@ fn simd_rem(hh: u64, aa: &[S], bb: &[S], out: &mut [S]) { // Precalculate the SIMD vectors of the permutations, to save time. // Output of this should be passed into the `perm_simd` argument of minhash. 
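// A minimal usage sketch (illustrative only; the permutation iterators, the
// input string, and the exact argument order of `minhash` are assumptions for
// this sketch rather than facts taken from this file):
//
//     let num_hashes = 32;
//     let perm_a = load_simd((0..num_hashes as u64).map(|i| 2 * i + 1), num_hashes);
//     let perm_b = load_simd((0..num_hashes as u64).map(|i| i + 7), num_hashes);
//     // 3-gram minhash signature over whitespace-separated words:
//     let sig = minhash("the quick brown fox jumps", (&perm_a, &perm_b), num_hashes, 3, 42)?;
//     assert_eq!(sig.len(), num_hashes);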
pub fn load_simd(mut v: impl Iterator, num_hashes: usize) -> Vec { - let num_simd = (num_hashes + SIMD_LANES - 1) / SIMD_LANES; + let num_simd = num_hashes.div_ceil(SIMD_LANES); let mut out = Vec::with_capacity(num_simd); loop { @@ -71,7 +71,7 @@ pub fn minhash( seed: u32, ) -> DaftResult> { let (perm_a_simd, perm_b_simd) = perm_simd; - let num_simd = (num_hashes + SIMD_LANES - 1) / SIMD_LANES; + let num_simd = num_hashes.div_ceil(SIMD_LANES); let mut out: Vec = vec![MAX_HASH_SIMD; num_simd]; @@ -86,7 +86,7 @@ pub fn minhash( let s_bytes = s.as_bytes(); if spaces.len() < ngram_size { // hash whole string at once - hashes.push(murmurhash3_x86_32(s_bytes, seed) as u64); + hashes.push(u64::from(murmurhash3_x86_32(s_bytes, seed))); } else { for i in 0..ngram_count { // looking at the substring that starts BEFORE the current space @@ -97,7 +97,10 @@ pub fn minhash( } else { spaces[i + ngram_size - 1] }; - hashes.push(murmurhash3_x86_32(&s_bytes[start_ind..end_ind], seed) as u64); + hashes.push(u64::from(murmurhash3_x86_32( + &s_bytes[start_ind..end_ind], + seed, + ))); if hashes.len() >= SIMD_LANES { // We have enough hashes to run with SIMD let hashes_simd = S::from_slice(&hashes); @@ -113,7 +116,7 @@ pub fn minhash( } let rem_out: Vec = out .iter() - .flat_map(|x| x.as_array()) + .flat_map(std::simd::Simd::as_array) .take(num_hashes) .map(|x| *x as u32) .collect(); @@ -151,7 +154,7 @@ mod tests { let aa = vec![simd_a]; let simd_b = S::splat(33); let bb = vec![simd_b]; - let simd_out = S::splat(123456); + let simd_out = S::splat(123_456); let mut out = vec![simd_out]; simd_min(simd_h, &aa, &bb, &mut out); let out_arr = out[0].as_array(); diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index a3b36d4a34..02b57ed6f9 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -29,7 +29,7 @@ use crate::{ UnableToParseSchemaFromMetadataSnafu, UnableToRunExpressionOnStatsSnafu, }; -pub(crate) struct ParquetReaderBuilder { +pub struct ParquetReaderBuilder { pub uri: String, pub metadata: parquet2::metadata::FileMetaData, selected_columns: Option>, @@ -100,7 +100,7 @@ where } } -pub(crate) fn build_row_ranges( +pub fn build_row_ranges( limit: Option, row_start_offset: usize, row_groups: Option<&[i64]>, @@ -155,7 +155,7 @@ pub(crate) fn build_row_ranges( } else { let mut rows_to_add = limit.unwrap_or(metadata.num_rows as i64); - for (i, rg) in metadata.row_groups.iter() { + for (i, rg) in &metadata.row_groups { if (curr_row_index + rg.num_rows()) < row_start_offset { curr_row_index += rg.num_rows(); continue; @@ -297,13 +297,13 @@ impl ParquetReaderBuilder { } #[derive(Copy, Clone)] -pub(crate) struct RowGroupRange { +pub struct RowGroupRange { pub row_group_index: usize, pub start: usize, pub num_rows: usize, } -pub(crate) struct ParquetFileReader { +pub struct ParquetFileReader { uri: String, metadata: Arc, arrow_schema: arrow2::datatypes::SchemaRef, @@ -351,7 +351,7 @@ impl ParquetFileReader { .unwrap(); let columns = rg.columns(); - for field in arrow_fields.iter() { + for field in arrow_fields { let field_name = field.name.clone(); let filtered_cols = columns .iter() @@ -454,7 +454,7 @@ impl ParquetFileReader { Vec::with_capacity(filtered_columns.len()); let mut ptypes = Vec::with_capacity(filtered_columns.len()); let mut num_values = Vec::with_capacity(filtered_columns.len()); - for col in filtered_columns.into_iter() { + for col in filtered_columns { num_values.push(col.metadata().num_values as usize); 
ptypes.push(col.descriptor().descriptor.primitive_type.clone()); @@ -481,8 +481,10 @@ impl ParquetFileReader { let page_stream = streaming_decompression(compressed_page_stream); let pinned_stream = Box::pin(page_stream); - decompressed_iters - .push(StreamIterator::new(pinned_stream, rt_handle.clone())) + decompressed_iters.push(StreamIterator::new( + pinned_stream, + rt_handle.clone(), + )); } let arr_iter = column_iter_to_arrays( decompressed_iters, @@ -605,9 +607,9 @@ impl ParquetFileReader { let handle = tokio::task::spawn(async move { let mut range_readers = Vec::with_capacity(filtered_cols_idx.len()); - for range in needed_byte_ranges.into_iter() { + for range in needed_byte_ranges { let range_reader = ranges.get_range_reader(range).await?; - range_readers.push(Box::pin(range_reader)) + range_readers.push(Box::pin(range_reader)); } let mut decompressed_iters = @@ -643,7 +645,7 @@ impl ParquetFileReader { let page_stream = streaming_decompression(compressed_page_stream); let pinned_stream = Box::pin(page_stream); decompressed_iters - .push(StreamIterator::new(pinned_stream, rt_handle.clone())) + .push(StreamIterator::new(pinned_stream, rt_handle.clone())); } let (send, recv) = tokio::sync::oneshot::channel(); @@ -788,9 +790,9 @@ impl ParquetFileReader { let handle = tokio::task::spawn(async move { let mut range_readers = Vec::with_capacity(filtered_cols_idx.len()); - for range in needed_byte_ranges.into_iter() { + for range in needed_byte_ranges { let range_reader = ranges.get_range_reader(range).await?; - range_readers.push(Box::pin(range_reader)) + range_readers.push(Box::pin(range_reader)); } let mut decompressed_iters = @@ -827,7 +829,7 @@ impl ParquetFileReader { let page_stream = streaming_decompression(compressed_page_stream); let pinned_stream = Box::pin(page_stream); decompressed_iters - .push(StreamIterator::new(pinned_stream, rt_handle.clone())) + .push(StreamIterator::new(pinned_stream, rt_handle.clone())); } let (send, recv) = tokio::sync::oneshot::channel(); diff --git a/src/daft-parquet/src/metadata.rs b/src/daft-parquet/src/metadata.rs index c769262d8e..32c1090ddd 100644 --- a/src/daft-parquet/src/metadata.rs +++ b/src/daft-parquet/src/metadata.rs @@ -5,7 +5,7 @@ use daft_core::datatypes::Field; use daft_dsl::common_treenode::{Transformed, TreeNode, TreeNodeRecursion}; use daft_io::{IOClient, IOStatsRef}; pub use parquet2::metadata::{FileMetaData, RowGroupMetaData}; -use parquet2::{metadata::RowGroupList, read::deserialize_metadata, schema::types::ParquetType}; +use parquet2::{read::deserialize_metadata, schema::types::ParquetType}; use snafu::ResultExt; use crate::{Error, JoinSnafu, UnableToParseMetadataSnafu}; @@ -24,7 +24,7 @@ impl TreeNode for ParquetTypeWrapper { match &self.0 { ParquetType::PrimitiveType(..) => Ok(TreeNodeRecursion::Jump), ParquetType::GroupType { fields, .. } => { - for child in fields.iter() { + for child in fields { // TODO: Expensive clone here because of ParquetTypeWrapper type, can we get rid of this? match op(&Self(child.clone()))? 
{ TreeNodeRecursion::Continue => {} @@ -105,8 +105,7 @@ fn rewrite_parquet_type_with_field_id_mapping( fields.retain(|f| { f.get_field_info() .id - .map(|field_id| field_id_mapping.contains_key(&field_id)) - .unwrap_or(false) + .is_some_and(|field_id| field_id_mapping.contains_key(&field_id)) }); } }; @@ -125,10 +124,7 @@ fn apply_field_ids_to_parquet_type( field_id_mapping: &BTreeMap, ) -> Option { let field_id = parquet_type.get_field_info().id; - if field_id - .map(|field_id| field_id_mapping.contains_key(&field_id)) - .unwrap_or(false) - { + if field_id.is_some_and(|field_id| field_id_mapping.contains_key(&field_id)) { let rewritten_pq_type = ParquetTypeWrapper(parquet_type) .transform(&|pq_type| { rewrite_parquet_type_with_field_id_mapping(pq_type, field_id_mapping) @@ -178,7 +174,7 @@ fn apply_field_ids_to_parquet_file_metadata( }) .collect::>(); - let new_row_groups_list = file_metadata + let new_row_groups = file_metadata .row_groups .into_values() .map(|rg| { @@ -207,9 +203,8 @@ fn apply_field_ids_to_parquet_file_metadata( new_total_uncompressed_size, ) }) - .collect::>(); - - let new_row_groups = RowGroupList::from_iter(new_row_groups_list.into_iter().enumerate()); + .enumerate() + .collect(); Ok(FileMetaData { row_groups: new_row_groups, diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs index 2d965053c2..23b627612e 100644 --- a/src/daft-parquet/src/python.rs +++ b/src/daft-parquet/src/python.rs @@ -46,7 +46,7 @@ pub mod pylib { row_groups, predicate.map(|e| e.expr), io_client, - Some(io_stats.clone()), + Some(io_stats), multithreaded_io.unwrap_or(true), schema_infer_options, None, @@ -74,7 +74,7 @@ pub mod pylib { .into_iter() .map(|v| { v.into_iter() - .map(|a| to_py_array(py, a, pyarrow).map(|pyarray| pyarray.unbind())) + .map(|a| to_py_array(py, a, pyarrow).map(pyo3::Bound::unbind)) .collect::>>() }) .collect::>>()?; @@ -172,7 +172,7 @@ pub mod pylib { None, )? 
.into_iter() - .map(|v| v.into()) + .map(std::convert::Into::into) .collect()) }) } diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index 3b6c498cf6..647ad5f7bd 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -74,6 +74,7 @@ pub struct ParquetSchemaInferenceOptions { } impl ParquetSchemaInferenceOptions { + #[must_use] pub fn new(coerce_int96_timestamp_unit: Option) -> Self { let coerce_int96_timestamp_unit = coerce_int96_timestamp_unit.unwrap_or(TimeUnit::Nanoseconds); @@ -124,7 +125,7 @@ fn limit_with_delete_rows( } else { delete_rows.iter().map(|r| *r as usize).collect::>() }; - delete_rows_sorted.sort(); + delete_rows_sorted.sort_unstable(); delete_rows_sorted.dedup(); for r in delete_rows_sorted { @@ -162,7 +163,7 @@ async fn read_parquet_single( let columns_to_return = columns; let num_rows_to_return = num_rows; let mut num_rows_to_read = num_rows; - let requested_columns = columns_to_read.as_ref().map(|v| v.len()); + let requested_columns = columns_to_read.as_ref().map(std::vec::Vec::len); if let Some(ref pred) = predicate { num_rows_to_read = None; @@ -375,11 +376,11 @@ async fn stream_parquet_single( maintain_order: bool, ) -> DaftResult> + Send> { let field_id_mapping_provided = field_id_mapping.is_some(); - let columns_to_return = columns.map(|s| s.iter().map(|s| s.to_string()).collect_vec()); + let columns_to_return = columns.map(|s| s.iter().map(|s| (*s).to_string()).collect_vec()); let num_rows_to_return = num_rows; let mut num_rows_to_read = num_rows; - let mut columns_to_read = columns.map(|s| s.iter().map(|s| s.to_string()).collect_vec()); - let requested_columns = columns_to_read.as_ref().map(|v| v.len()); + let mut columns_to_read = columns.map(|s| s.iter().map(|s| (*s).to_string()).collect_vec()); + let requested_columns = columns_to_read.as_ref().map(std::vec::Vec::len); if let Some(ref pred) = predicate { num_rows_to_read = None; @@ -575,7 +576,7 @@ async fn read_parquet_single_into_arrow( let rows_per_row_groups = metadata .row_groups .values() - .map(|m| m.num_rows()) + .map(parquet2::metadata::RowGroupMetaData::num_rows) .collect::>(); let metadata_num_rows = metadata.num_rows; @@ -767,7 +768,7 @@ pub fn read_parquet_bulk>( let tables = runtime_handle .block_on_current_thread(async move { let task_stream = futures::stream::iter(uris.iter().enumerate().map(|(i, uri)| { - let uri = uri.to_string(); + let uri = (*uri).to_string(); let owned_columns = columns.clone(); let owned_row_group = row_groups.as_ref().and_then(|rgs| rgs[i].clone()); let owned_predicate = predicate.clone(); @@ -885,7 +886,7 @@ pub fn read_parquet_into_pyarrow_bulk>( let tables = runtime_handle .block_on_current_thread(async move { futures::stream::iter(uris.iter().enumerate().map(|(i, uri)| { - let uri = uri.to_string(); + let uri = (*uri).to_string(); let owned_columns = columns.clone(); let owned_row_group = row_groups.as_ref().and_then(|rgs| rgs[i].clone()); @@ -957,7 +958,7 @@ pub async fn read_parquet_metadata_bulk( field_id_mapping: Option>>, ) -> DaftResult> { let handles_iter = uris.iter().map(|uri| { - let owned_string = uri.to_string(); + let owned_string = (*uri).to_string(); let owned_client = io_client.clone(); let owned_io_stats = io_stats.clone(); let owned_field_id_mapping = field_id_mapping.clone(); @@ -997,7 +998,7 @@ pub fn read_parquet_statistics( let values = path_array.as_arrow(); let handles_iter = values.iter().map(|uri| { - let owned_string = uri.map(|v| v.to_string()); + let owned_string = 
uri.map(std::string::ToString::to_string); let owned_client = io_client.clone(); let io_stats = io_stats.clone(); let owned_field_id_mapping = field_id_mapping.clone(); diff --git a/src/daft-parquet/src/read_planner.rs b/src/daft-parquet/src/read_planner.rs index aca3b3c870..85b67ac70c 100644 --- a/src/daft-parquet/src/read_planner.rs +++ b/src/daft-parquet/src/read_planner.rs @@ -68,7 +68,7 @@ impl ReadPlanPass for SplitLargeRequestPass { } let mut new_ranges = vec![]; - for range in ranges.iter() { + for range in &ranges { if (range.end - range.start) > self.split_threshold { let mut curr_start = range.start; while curr_start < range.end { @@ -99,8 +99,8 @@ struct RangeCacheEntry { impl RangeCacheEntry { async fn get_or_wait(&self, range: Range) -> std::result::Result { { - let mut _guard = self.state.lock().await; - match &mut *_guard { + let mut guard = self.state.lock().await; + match &mut *guard { RangeCacheState::InFlight(f) => { // TODO(sammy): thread in url for join error let v = f @@ -112,7 +112,7 @@ impl RangeCacheEntry { .as_ref() .map(|b| b.slice(range)) .map_err(|e| daft_io::Error::CachedError { source: e.clone() }); - *_guard = RangeCacheState::Ready(v); + *guard = RangeCacheState::Ready(v); sliced } RangeCacheState::Ready(v) => v @@ -124,7 +124,7 @@ impl RangeCacheEntry { } } -pub(crate) struct ReadPlanner { +pub struct ReadPlanner { source: String, ranges: RangeList, passes: Vec>, @@ -148,7 +148,7 @@ impl ReadPlanner { } pub fn run_passes(&mut self) -> super::Result<()> { - for pass in self.passes.iter() { + for pass in &self.passes { let (changed, ranges) = pass.run(&self.ranges)?; if changed { self.ranges = ranges; @@ -193,7 +193,7 @@ impl ReadPlanner { } } -pub(crate) struct RangesContainer { +pub struct RangesContainer { ranges: Vec>, } @@ -280,7 +280,7 @@ impl RangesContainer { impl Display for ReadPlanner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "ReadPlanBuilder: {} ranges", self.ranges.len())?; - for range in self.ranges.iter() { + for range in &self.ranges { writeln!( f, "{}-{}, {}", diff --git a/src/daft-parquet/src/statistics/column_range.rs b/src/daft-parquet/src/statistics/column_range.rs index 6910eb7ad5..ac62627ea4 100644 --- a/src/daft-parquet/src/statistics/column_range.rs +++ b/src/daft-parquet/src/statistics/column_range.rs @@ -10,8 +10,8 @@ use parquet2::{ use snafu::{OptionExt, ResultExt}; use super::{ - utils::*, DaftStatsSnafu, MissingParquetColumnStatisticsSnafu, - UnableToParseUtf8FromBinarySnafu, Wrap, + utils::{convert_i128, convert_i96_to_i64_timestamp}, + DaftStatsSnafu, MissingParquetColumnStatisticsSnafu, UnableToParseUtf8FromBinarySnafu, Wrap, }; impl TryFrom<&BooleanStatistics> for Wrap { @@ -389,7 +389,7 @@ fn convert_int96_column_range_statistics( Ok(ColumnRangeStatistics::Missing) } -pub(crate) fn parquet_statistics_to_column_range_statistics( +pub fn parquet_statistics_to_column_range_statistics( pq_stats: &dyn Statistics, daft_dtype: &DataType, ) -> Result { diff --git a/src/daft-parquet/src/statistics/mod.rs b/src/daft-parquet/src/statistics/mod.rs index 2827c84355..0aa3d4bb57 100644 --- a/src/daft-parquet/src/statistics/mod.rs +++ b/src/daft-parquet/src/statistics/mod.rs @@ -10,7 +10,7 @@ pub use table_stats::row_group_metadata_to_table_stats; #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] -pub(super) enum Error { +pub enum Error { #[snafu(display("MissingParquetColumnStatistics"))] MissingParquetColumnStatistics {}, #[snafu(display("UnableToParseParquetColumnStatistics: {source}"))] 
@@ -43,7 +43,7 @@ impl From for DaftError { } } -pub(super) struct Wrap(T); +pub struct Wrap(T); impl From for Wrap { fn from(value: T) -> Self { diff --git a/src/daft-parquet/src/statistics/utils.rs b/src/daft-parquet/src/statistics/utils.rs index 28b65e15ac..c80a533382 100644 --- a/src/daft-parquet/src/statistics/utils.rs +++ b/src/daft-parquet/src/statistics/utils.rs @@ -15,8 +15,8 @@ fn int96_to_i64_us(value: [u32; 3]) -> i64 { const SECONDS_PER_DAY: i64 = 86_400; const MICROS_PER_SECOND: i64 = 1_000_000; - let day = value[2] as i64; - let microseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000; + let day = i64::from(value[2]); + let microseconds = ((i64::from(value[1]) << 32) + i64::from(value[0])) / 1_000; let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; seconds * MICROS_PER_SECOND + microseconds @@ -28,8 +28,8 @@ fn int96_to_i64_ms(value: [u32; 3]) -> i64 { const SECONDS_PER_DAY: i64 = 86_400; const MILLIS_PER_SECOND: i64 = 1_000; - let day = value[2] as i64; - let milliseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000; + let day = i64::from(value[2]); + let milliseconds = ((i64::from(value[1]) << 32) + i64::from(value[0])) / 1_000_000; let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; seconds * MILLIS_PER_SECOND + milliseconds diff --git a/src/daft-parquet/src/stream_reader.rs b/src/daft-parquet/src/stream_reader.rs index 1e8c3f9d27..fd77efe886 100644 --- a/src/daft-parquet/src/stream_reader.rs +++ b/src/daft-parquet/src/stream_reader.rs @@ -48,7 +48,7 @@ fn prune_fields_from_schema( } } -pub(crate) fn arrow_column_iters_to_table_iter( +pub fn arrow_column_iters_to_table_iter( arr_iters: ArrowChunkIters, row_range_start: usize, schema_ref: SchemaRef, @@ -67,7 +67,10 @@ pub(crate) fn arrow_column_iters_to_table_iter( type Item = arrow2::error::Result; fn next(&mut self) -> Option { - self.iters.par_iter_mut().map(|iter| iter.next()).collect() + self.iters + .par_iter_mut() + .map(std::iter::Iterator::next) + .collect() } } let par_lock_step_iter = ParallelLockStepIter { iters: arr_iters }; @@ -75,7 +78,7 @@ pub(crate) fn arrow_column_iters_to_table_iter( // Keep track of the current index in the row group so we can throw away arrays that are not needed // and slice arrays that are partially needed. 
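// For example (illustrative numbers, not taken from this code): if the
// requested row range starts at row 1_000 and the first decoded chunk covers
// rows 0..1_500, the first 1_000 rows are sliced away and only the remaining
// 500 are emitted; `index_so_far` records how many row-group rows have been
// consumed so far.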
let mut index_so_far = 0; - let owned_schema_ref = schema_ref.clone(); + let owned_schema_ref = schema_ref; let table_iter = par_lock_step_iter.into_iter().map(move |chunk| { let chunk = chunk.with_context(|_| { super::UnableToCreateChunkFromStreamingFileReaderSnafu { path: uri.clone() } @@ -101,7 +104,7 @@ pub(crate) fn arrow_column_iters_to_table_iter( let len = all_series .first() - .map(|s| s.len()) + .map(daft_core::series::Series::len) .expect("All series should not be empty when creating table from parquet chunks"); if all_series.iter().any(|s| s.len() != len) { return Err(super::Error::ParquetColumnsDontHaveEqualRows { path: uri.clone() }.into()); @@ -176,12 +179,12 @@ where impl Drop for CountingReader { fn drop(&mut self) { - self.update_count() + self.update_count(); } } #[allow(clippy::too_many_arguments)] -pub(crate) fn local_parquet_read_into_column_iters( +pub fn local_parquet_read_into_column_iters( uri: &str, columns: Option<&[String]>, start_offset: Option, @@ -201,8 +204,8 @@ pub(crate) fn local_parquet_read_into_column_iters( const LOCAL_PROTOCOL: &str = "file://"; let uri = uri .strip_prefix(LOCAL_PROTOCOL) - .map(|s| s.to_string()) - .unwrap_or(uri.to_string()); + .map(std::string::ToString::to_string) + .unwrap_or_else(|| uri.to_string()); let reader = File::open(uri.clone()).with_context(|_| super::InternalIOSnafu { path: uri.to_string(), @@ -250,7 +253,7 @@ pub(crate) fn local_parquet_read_into_column_iters( num_rows, start_offset.unwrap_or(0), row_groups, - predicate.clone(), + predicate, &daft_schema, &metadata, &uri, @@ -286,7 +289,7 @@ pub(crate) fn local_parquet_read_into_column_iters( } #[allow(clippy::too_many_arguments)] -pub(crate) fn local_parquet_read_into_arrow( +pub fn local_parquet_read_into_arrow( uri: &str, columns: Option<&[String]>, start_offset: Option, @@ -426,7 +429,7 @@ pub(crate) fn local_parquet_read_into_arrow( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn local_parquet_read_async( +pub async fn local_parquet_read_async( uri: &str, columns: Option>, start_offset: Option, @@ -488,7 +491,7 @@ pub(crate) async fn local_parquet_read_async( } #[allow(clippy::too_many_arguments)] -pub(crate) fn local_parquet_stream( +pub fn local_parquet_stream( uri: &str, original_columns: Option>, columns: Option>, @@ -590,7 +593,7 @@ pub(crate) fn local_parquet_stream( } #[allow(clippy::too_many_arguments)] -pub(crate) async fn local_parquet_read_into_arrow_async( +pub async fn local_parquet_read_into_arrow_async( uri: &str, columns: Option>, start_offset: Option, diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 035065ef97..9fd5f96b3a 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -46,11 +46,13 @@ pub enum LocalPhysicalPlan { } impl LocalPhysicalPlan { + #[must_use] pub fn name(&self) -> &'static str { // uses strum::IntoStaticStr self.into() } + #[must_use] pub fn arced(self) -> LocalPhysicalPlanRef { self.into() } @@ -198,6 +200,7 @@ impl LocalPhysicalPlan { .arced() } + #[must_use] pub fn schema(&self) -> &SchemaRef { match self { Self::PhysicalScan(PhysicalScan { schema, .. 
}) @@ -217,13 +220,12 @@ impl LocalPhysicalPlan { } #[derive(Debug)] - pub struct InMemoryScan { pub info: InMemoryInfo, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct PhysicalScan { pub scan_tasks: Vec, pub schema: SchemaRef, @@ -231,37 +233,36 @@ pub struct PhysicalScan { } #[derive(Debug)] - pub struct EmptyScan { pub schema: SchemaRef, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct Project { pub input: LocalPhysicalPlanRef, pub projection: Vec, pub schema: SchemaRef, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct Filter { pub input: LocalPhysicalPlanRef, pub predicate: ExprRef, pub schema: SchemaRef, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct Limit { pub input: LocalPhysicalPlanRef, pub num_rows: i64, pub schema: SchemaRef, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct Sort { pub input: LocalPhysicalPlanRef, pub sort_by: Vec, @@ -269,16 +270,16 @@ pub struct Sort { pub schema: SchemaRef, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct UnGroupedAggregate { pub input: LocalPhysicalPlanRef, pub aggregations: Vec, pub schema: SchemaRef, pub plan_stats: PlanStats, } -#[derive(Debug)] +#[derive(Debug)] pub struct HashAggregate { pub input: LocalPhysicalPlanRef, pub aggregations: Vec, @@ -288,7 +289,6 @@ pub struct HashAggregate { } #[derive(Debug)] - pub struct HashJoin { pub left: LocalPhysicalPlanRef, pub right: LocalPhysicalPlanRef, @@ -299,7 +299,6 @@ pub struct HashJoin { } #[derive(Debug)] - pub struct Concat { pub input: LocalPhysicalPlanRef, pub other: LocalPhysicalPlanRef, @@ -308,8 +307,7 @@ pub struct Concat { } #[derive(Debug)] - pub struct PhysicalWrite {} -#[derive(Debug)] +#[derive(Debug)] pub struct PlanStats {} diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index dc38d88b4f..ce1683fed8 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -85,7 +85,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { )) } LogicalPlan::Distinct(distinct) => { - let schema = distinct.input.schema().clone(); + let schema = distinct.input.schema(); let input = translate(&distinct.input)?; let col_exprs = input .schema() diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 7de191fe81..13d2c4307f 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -34,6 +34,7 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +uuid = {version = "1", features = ["v4"]} [dev-dependencies] daft-dsl = {path = "../daft-dsl", features = ["test-utils"]} diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 982a3634a9..e8c651f6bc 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -1,17 +1,27 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; use common_daft_config::DaftPlanningConfig; use common_display::mermaid::MermaidDisplayOptions; use common_error::DaftResult; -use common_file_formats::FileFormat; +use common_file_formats::{FileFormat, FileFormatConfig, ParquetSourceConfig}; use common_io_config::IOConfig; -use daft_core::join::{JoinStrategy, JoinType}; +use daft_core::{ + join::{JoinStrategy, JoinType}, + prelude::TimeUnit, +}; use daft_dsl::{col, ExprRef}; -use daft_scan::{PhysicalScanInfo, Pushdowns, 
ScanOperatorRef}; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_scan::{ + glob::GlobScanOperator, + storage_config::{NativeStorageConfig, StorageConfig}, + PhysicalScanInfo, Pushdowns, ScanOperatorRef, +}; +use daft_schema::{ + field::Field, + schema::{Schema, SchemaRef}, +}; #[cfg(feature = "python")] use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, @@ -73,7 +83,29 @@ impl From<&LogicalPlanBuilder> for LogicalPlanRef { value.plan.clone() } } - +pub trait IntoGlobPath { + fn into_glob_path(self) -> Vec<String>; +} +impl IntoGlobPath for Vec<String> { + fn into_glob_path(self) -> Vec<String> { + self + } +} +impl IntoGlobPath for String { + fn into_glob_path(self) -> Vec<String> { + vec![self] + } +} +impl IntoGlobPath for &str { + fn into_glob_path(self) -> Vec<String> { + vec![self.to_string()] + } +} +impl IntoGlobPath for Vec<&str> { + fn into_glob_path(self) -> Vec<String> { + self.iter().map(|s| (*s).to_string()).collect() + } +} impl LogicalPlanBuilder { /// Replace the LogicalPlanBuilder's plan with the provided plan pub fn with_new_plan<LP: Into<Arc<LogicalPlan>>>(&self, plan: LP) -> Self { @@ -103,11 +135,52 @@ impl LogicalPlanBuilder { num_rows, None, // TODO(sammy) thread through clustering spec to Python )); - let logical_plan: LogicalPlan = - logical_ops::Source::new(schema.clone(), source_info.into()).into(); + let logical_plan: LogicalPlan = logical_ops::Source::new(schema, source_info.into()).into(); + Ok(Self::new(logical_plan.into(), None)) } + #[cfg(feature = "python")] + pub fn delta_scan<T: AsRef<str>>( + glob_path: T, + io_config: Option<IOConfig>, + multithreaded_io: bool, + ) -> DaftResult<Self> { + use daft_scan::storage_config::PyStorageConfig; + + Python::with_gil(|py| { + let io_config = io_config.unwrap_or_default(); + + let native_storage_config = NativeStorageConfig { + io_config: Some(io_config), + multithreaded_io, + }; + + let py_storage_config: PyStorageConfig = + Arc::new(StorageConfig::Native(Arc::new(native_storage_config))).into(); + + // let py_io_config = PyIOConfig { config: io_config }; + let delta_lake_scan = PyModule::import_bound(py, "daft.delta_lake.delta_lake_scan")?; + let delta_lake_scan_operator = + delta_lake_scan.getattr(pyo3::intern!(py, "DeltaLakeScanOperator"))?; + let delta_lake_operator = delta_lake_scan_operator + .call1((glob_path.as_ref(), py_storage_config))? + .to_object(py); + let scan_operator_handle = + ScanOperatorHandle::from_python_scan_operator(delta_lake_operator, py)?; + Self::table_scan(scan_operator_handle.into(), None) + }) + } + + #[cfg(not(feature = "python"))] + pub fn delta_scan<T>( + glob_path: T, + io_config: Option<IOConfig>, + multithreaded_io: bool, + ) -> DaftResult<Self> { + panic!("Delta Lake scan requires the 'python' feature to be enabled.") + } + pub fn table_scan( scan_operator: ScanOperatorRef, pushdowns: Option<Pushdowns>, @@ -135,13 +208,17 @@ impl LogicalPlanBuilder { .collect::<Vec<_>>(); Arc::new(Schema::new(pruned_upstream_schema)?)
} else { - schema.clone() + schema }; let logical_plan: LogicalPlan = logical_ops::Source::new(output_schema, source_info.into()).into(); Ok(Self::new(logical_plan.into(), None)) } + pub fn parquet_scan<T: IntoGlobPath>(glob_path: T) -> ParquetScanBuilder { + ParquetScanBuilder::new(glob_path) + } + pub fn select(&self, to_select: Vec<ExprRef>) -> DaftResult<Self> { let logical_plan: LogicalPlan = logical_ops::Project::try_new(self.plan.clone(), to_select)?.into(); @@ -351,7 +428,7 @@ impl LogicalPlanBuilder { ) -> DaftResult<Self> { let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), - right.into().clone(), + right.into(), left_on, right_on, join_type, @@ -498,6 +575,95 @@ impl LogicalPlanBuilder { } } +pub struct ParquetScanBuilder { + pub glob_paths: Vec<String>, + pub infer_schema: bool, + pub coerce_int96_timestamp_unit: TimeUnit, + pub field_id_mapping: Option<Arc<BTreeMap<i32, Field>>>, + pub row_groups: Option<Vec<Option<Vec<i64>>>>, + pub chunk_size: Option<usize>, + pub io_config: Option<IOConfig>, + pub multithreaded: bool, + pub schema: Option<SchemaRef>, +} + +impl ParquetScanBuilder { + pub fn new<T: IntoGlobPath>(glob_paths: T) -> Self { + let glob_paths = glob_paths.into_glob_path(); + Self::new_impl(glob_paths) + } + + // concrete implementation to reduce LLVM code duplication + fn new_impl(glob_paths: Vec<String>) -> Self { + Self { + glob_paths, + infer_schema: true, + coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + field_id_mapping: None, + row_groups: None, + chunk_size: None, + multithreaded: true, + schema: None, + io_config: None, + } + } + pub fn infer_schema(mut self, infer_schema: bool) -> Self { + self.infer_schema = infer_schema; + self + } + pub fn coerce_int96_timestamp_unit(mut self, unit: TimeUnit) -> Self { + self.coerce_int96_timestamp_unit = unit; + self + } + pub fn field_id_mapping(mut self, field_id_mapping: Arc<BTreeMap<i32, Field>>) -> Self { + self.field_id_mapping = Some(field_id_mapping); + self + } + pub fn row_groups(mut self, row_groups: Vec<Option<Vec<i64>>>) -> Self { + self.row_groups = Some(row_groups); + self + } + pub fn chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = Some(chunk_size); + self + } + + pub fn io_config(mut self, io_config: IOConfig) -> Self { + self.io_config = Some(io_config); + self + } + + pub fn multithreaded(mut self, multithreaded: bool) -> Self { + self.multithreaded = multithreaded; + self + } + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn finish(self) -> DaftResult<LogicalPlanBuilder> { + let cfg = ParquetSourceConfig { + coerce_int96_timestamp_unit: self.coerce_int96_timestamp_unit, + field_id_mapping: self.field_id_mapping, + row_groups: self.row_groups, + chunk_size: self.chunk_size, + }; + + let operator = Arc::new(GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Parquet(cfg)), + Arc::new(StorageConfig::Native(Arc::new( + NativeStorageConfig::new_internal(self.multithreaded, self.io_config), + ))), + self.infer_schema, + self.schema, + )?); + + LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), None) + } +} + /// A Python-facing wrapper of the LogicalPlanBuilder.
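Taken together, `IntoGlobPath` and `ParquetScanBuilder` above give a fluent way to build a Parquet scan plan natively. A minimal usage sketch, assuming the `ParquetScanBuilder` re-export added in `lib.rs` below; the glob paths are hypothetical:

```
use daft_plan::ParquetScanBuilder;

fn main() -> common_error::DaftResult<()> {
    // A bare &str is accepted thanks to the IntoGlobPath impls above...
    let _plan = ParquetScanBuilder::new("data/*.parquet") // hypothetical path
        .infer_schema(true)
        .chunk_size(128 * 1024)
        .finish()?;

    // ...and so is a Vec of paths, with no separate constructor needed.
    let _multi = ParquetScanBuilder::new(vec!["a/*.parquet", "b/*.parquet"])
        .multithreaded(false)
        .finish()?;

    Ok(())
}
```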
/// /// This lightweight proxy interface should hold as much of the Python-specific logic diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 50e309916b..2541a143db 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -19,7 +19,7 @@ pub mod source_info; mod test; mod treenode; -pub use builder::{LogicalPlanBuilder, PyLogicalPlanBuilder}; +pub use builder::{LogicalPlanBuilder, ParquetScanBuilder, PyLogicalPlanBuilder}; pub use daft_core::join::{JoinStrategy, JoinType}; pub use logical_plan::{LogicalPlan, LogicalPlanRef}; pub use partitioning::ClusteringSpec; diff --git a/src/daft-plan/src/logical_ops/agg.rs b/src/daft-plan/src/logical_ops/agg.rs index 1eddc14a5f..2a7be5337c 100644 --- a/src/daft-plan/src/logical_ops/agg.rs +++ b/src/daft-plan/src/logical_ops/agg.rs @@ -41,10 +41,10 @@ impl Aggregate { let output_schema = Schema::new(fields).context(CreationSnafu)?.into(); Ok(Self { + input, aggregations, groupby, output_schema, - input, }) } diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 2a68390066..d219d24211 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ col, @@ -13,6 +13,7 @@ use daft_dsl::{ }; use itertools::Itertools; use snafu::ResultExt; +use uuid::Uuid; use crate::{ logical_ops::Project, @@ -54,14 +55,31 @@ impl Join { join_type: JoinType, join_strategy: Option, ) -> logical_plan::Result { - let (left_on, left_fields) = - resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; - let (right_on, right_fields) = + let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; + let (right_on, _) = resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?; - for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] { - let on_schema = Schema::new(on_fields).context(CreationSnafu)?; - for (field, expr) in on_schema.fields.values().zip(on_exprs.iter()) { + let (unique_left_on, unique_right_on) = + Self::rename_join_keys(left_on.clone(), right_on.clone()); + + let left_fields: Vec = unique_left_on + .iter() + .map(|e| e.to_field(&left.schema())) + .collect::>>() + .context(CreationSnafu)?; + + let right_fields: Vec = unique_right_on + .iter() + .map(|e| e.to_field(&right.schema())) + .collect::>>() + .context(CreationSnafu)?; + + for (on_exprs, on_fields) in [ + (&unique_left_on, &left_fields), + (&unique_right_on, &right_fields), + ] { + for (field, expr) in on_fields.iter().zip(on_exprs.iter()) { + // Null type check for both fields and expressions if matches!(field.dtype, DataType::Null) { return Err(DaftError::ValueError(format!( "Can't join on null type expressions: {expr}" @@ -74,7 +92,7 @@ impl Join { if matches!(join_type, JoinType::Anti | JoinType::Semi) { // The output schema is the same as the left input schema for anti and semi joins. - let output_schema = left.schema().clone(); + let output_schema = left.schema(); Ok(Self { left, @@ -167,6 +185,60 @@ impl Join { } } + /// Renames join keys for the given left and right expressions. This is required to + /// prevent errors when the join keys on the left and right expressions have the same key + /// name. + /// + /// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and + /// checks for pairs of column expressions that differ. 
If both expressions in a pair + /// are column expressions and they are not identical, it generates a unique identifier + /// and renames both expressions by appending this identifier to their original names. + /// + /// The function returns two vectors of expressions, where the renamed expressions are + /// substituted for the original expressions in the cases where renaming occurred. + /// + /// # Parameters + /// - `left_exprs`: A vector of expressions from the left side of a join. + /// - `right_exprs`: A vector of expressions from the right side of a join. + /// + /// # Returns + /// A tuple containing two vectors of expressions, one for the left side and one for the + /// right side, where expressions that needed to be renamed have been modified. + /// + /// # Example + /// ``` + /// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions); + /// ``` + /// + /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). + + fn rename_join_keys( + left_exprs: Vec<Arc<Expr>>, + right_exprs: Vec<Arc<Expr>>, + ) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) { + left_exprs + .into_iter() + .zip(right_exprs) + .map( + |(left_expr, right_expr)| match (&*left_expr, &*right_expr) { + (Expr::Column(left_name), Expr::Column(right_name)) + if left_name == right_name => + { + (left_expr, right_expr) + } + _ => { + let unique_id = Uuid::new_v4().to_string(); + let renamed_left_expr = + left_expr.alias(format!("{}_{}", left_expr.name(), unique_id)); + let renamed_right_expr = + right_expr.alias(format!("{}_{}", right_expr.name(), unique_id)); + (renamed_left_expr, renamed_right_expr) + } + }, + ) + .unzip() + } + pub fn multiline_display(&self) -> Vec<String> { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); diff --git a/src/daft-plan/src/logical_ops/project.rs index 41101fcd17..78de22bea6 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -99,7 +99,7 @@ impl Project { // all existing names must also be converted to semantic IDs. let mut column_name_substitutions = IndexMap::new(); - let mut exprs_to_walk: Vec<Arc<Expr>> = exprs.to_vec(); + let mut exprs_to_walk: Vec<Arc<Expr>> = exprs.clone(); while !exprs_to_walk.is_empty() { exprs_to_walk = exprs_to_walk .iter() @@ -121,7 +121,7 @@ } else { // If previously seen, cache the expression (if it involves computation) if optimization::requires_computation(expr) { - subexpressions_to_cache.insert(expr_id.clone(), expr.clone()); + subexpressions_to_cache.insert(expr_id, expr.clone()); } // Stop recursing if previously seen; // we only want top-level repeated subexpressions @@ -133,7 +133,7 @@ } if subexpressions_to_cache.is_empty() { - (exprs.to_vec(), IndexMap::new()) + (exprs, IndexMap::new()) } else { // Then, substitute all the cached subexpressions in the original expressions.
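The renaming rule above is easiest to see on plain strings. A self-contained toy rendition (using the `uuid` crate this diff adds to `daft-plan`): keys with the same name on both sides pass through untouched, while differing pairs get a shared UUID suffix so the two sides cannot collide after the join:

```
use uuid::Uuid;

// Toy version of rename_join_keys over plain strings instead of Expr trees.
fn rename_join_keys(left: Vec<String>, right: Vec<String>) -> (Vec<String>, Vec<String>) {
    left.into_iter()
        .zip(right)
        .map(|(l, r)| {
            if l == r {
                // Same key name on both sides: leave as-is.
                (l, r)
            } else {
                // Differing pair: suffix both with one fresh UUID.
                let id = Uuid::new_v4().to_string();
                (format!("{l}_{id}"), format!("{r}_{id}"))
            }
        })
        .unzip()
}

fn main() {
    let (l, r) = rename_join_keys(
        vec!["id".into(), "a".into()],
        vec!["id".into(), "b".into()],
    );
    assert_eq!(l[0], "id"); // identical pair kept as-is
    assert!(l[1].starts_with("a_") && r[1].starts_with("b_"));
    assert_eq!(&l[1][2..], &r[1][2..]); // both carry the same suffix
    println!("{l:?} / {r:?}");
}
```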
let subexprs_to_replace = subexpressions_to_cache @@ -154,7 +154,7 @@ impl Project { if new_expr.name() != old_name { new_expr.alias(old_name) } else { - new_expr.clone() + new_expr } }) .collect::>(); @@ -182,7 +182,7 @@ fn replace_column_with_semantic_id( let sem_id = e.semantic_id(schema); if subexprs_to_replace.contains(&sem_id) { - let new_expr = Expr::Column(sem_id.id.clone()); + let new_expr = Expr::Column(sem_id.id); let new_expr = match e.as_ref() { Expr::Alias(_, name) => Expr::Alias(new_expr.into(), name.clone()), _ => new_expr, @@ -246,9 +246,7 @@ fn replace_column_with_semantic_id( if !child.transformed && !fill_value.transformed { Transformed::no(e) } else { - Transformed::yes( - Expr::FillNull(child.data.clone(), fill_value.data.clone()).into(), - ) + Transformed::yes(Expr::FillNull(child.data, fill_value.data).into()) } } Expr::IsIn(child, items) => { @@ -259,7 +257,7 @@ fn replace_column_with_semantic_id( if !child.transformed && !items.transformed { Transformed::no(e) } else { - Transformed::yes(Expr::IsIn(child.data.clone(), items.data.clone()).into()) + Transformed::yes(Expr::IsIn(child.data, items.data).into()) } } Expr::Between(child, lower, upper) => { @@ -272,10 +270,7 @@ fn replace_column_with_semantic_id( if !child.transformed && !lower.transformed && !upper.transformed { Transformed::no(e) } else { - Transformed::yes( - Expr::Between(child.data.clone(), lower.data.clone(), upper.data.clone()) - .into(), - ) + Transformed::yes(Expr::Between(child.data, lower.data, upper.data).into()) } } Expr::BinaryOp { op, left, right } => { @@ -289,8 +284,8 @@ fn replace_column_with_semantic_id( Transformed::yes( Expr::BinaryOp { op: *op, - left: left.data.clone(), - right: right.data.clone(), + left: left.data, + right: right.data, } .into(), ) @@ -312,9 +307,9 @@ fn replace_column_with_semantic_id( } else { Transformed::yes( Expr::IfElse { - predicate: predicate.data.clone(), - if_true: if_true.data.clone(), - if_false: if_false.data.clone(), + predicate: predicate.data, + if_true: if_true.data, + if_false: if_false.data, } .into(), ) @@ -373,24 +368,24 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::Count(ref child, mode) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::Count(transformed_child, mode), - |_| e.clone(), + |_| e, ) } AggExpr::Sum(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Sum, |_| e.clone()) + .map_yes_no(AggExpr::Sum, |_| e) } AggExpr::ApproxPercentile(ApproxPercentileParams { ref child, ref percentiles, - ref force_list_output, + force_list_output, }) => replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no( |transformed_child| { AggExpr::ApproxPercentile(ApproxPercentileParams { child: transformed_child, percentiles: percentiles.clone(), - force_list_output: *force_list_output, + force_list_output, }) }, |_| e.clone(), @@ -402,40 +397,44 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::ApproxSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::ApproxSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } AggExpr::MergeSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::MergeSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } 
AggExpr::Mean(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Mean, |_| e.clone()) + .map_yes_no(AggExpr::Mean, |_| e) + } + AggExpr::Stddev(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::Stddev, |_| e) } AggExpr::Min(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Min, |_| e.clone()) + .map_yes_no(AggExpr::Min, |_| e) } AggExpr::Max(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Max, |_| e.clone()) + .map_yes_no(AggExpr::Max, |_| e) } AggExpr::AnyValue(ref child, ignore_nulls) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls), - |_| e.clone(), + |_| e, ) } AggExpr::List(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::List, |_| e.clone()) + .map_yes_no(AggExpr::List, |_| e) } AggExpr::Concat(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Concat, |_| e.clone()) + .map_yes_no(AggExpr::Concat, |_| e) } AggExpr::MapGroups { func, inputs } => { let transforms = inputs @@ -446,7 +445,7 @@ fn replace_column_with_semantic_id_aggexpr( Transformed::no(AggExpr::MapGroups { func, inputs }) } else { Transformed::yes(AggExpr::MapGroups { - func: func.clone(), + func, inputs: transforms.iter().map(|t| t.data.clone()).collect(), }) } @@ -487,26 +486,24 @@ mod tests { let a4 = binary_op(Operator::Plus, a2.clone(), a2.clone()); let a4_colname = a4.semantic_id(&source.schema()).id; - let a8 = binary_op(Operator::Plus, a4.clone(), a4.clone()); + let a8 = binary_op(Operator::Plus, a4.clone(), a4); let expressions = vec![a8.alias("x")]; - let result_projection = Project::try_new(source.clone(), expressions)?; + let result_projection = Project::try_new(source, expressions)?; let a4_col = col(a4_colname.clone()); let expected_result_projection = - vec![binary_op(Operator::Plus, a4_col.clone(), a4_col.clone()).alias("x")]; + vec![binary_op(Operator::Plus, a4_col.clone(), a4_col).alias("x")]; assert_eq!(result_projection.projection, expected_result_projection); let a2_col = col(a2_colname.clone()); let expected_subprojection = - vec![ - binary_op(Operator::Plus, a2_col.clone(), a2_col.clone()).alias(a4_colname.clone()) - ]; + vec![binary_op(Operator::Plus, a2_col.clone(), a2_col).alias(a4_colname)]; let LogicalPlan::Project(subprojection) = result_projection.input.as_ref() else { panic!() }; assert_eq!(subprojection.projection, expected_subprojection); - let expected_third_projection = vec![a2.alias(a2_colname.clone())]; + let expected_third_projection = vec![a2.alias(a2_colname)]; let LogicalPlan::Project(third_projection) = subprojection.input.as_ref() else { panic!() }; @@ -533,10 +530,10 @@ mod tests { let a2_colname = a2.semantic_id(&source.schema()).id; let expressions = vec![ - a2.clone().alias("x"), + a2.alias("x"), binary_op(Operator::Plus, a2.clone(), col("a")).alias("y"), ]; - let result_projection = Project::try_new(source.clone(), expressions)?; + let result_projection = Project::try_new(source, expressions)?; let a2_col = col(a2_colname.clone()); let expected_result_projection = vec![ @@ -545,8 +542,7 @@ mod tests { ]; assert_eq!(result_projection.projection, expected_result_projection); - 
let expected_subprojection = - vec![a2.clone().alias(a2_colname.clone()), col("a").alias("a")]; + let expected_subprojection = vec![a2.alias(a2_colname), col("a").alias("a")]; let LogicalPlan::Project(subprojection) = result_projection.input.as_ref() else { panic!() }; diff --git a/src/daft-plan/src/logical_ops/sample.rs b/src/daft-plan/src/logical_ops/sample.rs index 7b63c5b6ad..9d96594666 100644 --- a/src/daft-plan/src/logical_ops/sample.rs +++ b/src/daft-plan/src/logical_ops/sample.rs @@ -22,6 +22,7 @@ impl Hash for Sample { self.input.hash(state); // Convert the `f64` to a stable format with 6 decimal places. + #[expect(clippy::collection_is_never_read, reason = "nursery bug pretty sure")] let fraction_str = format!("{:.6}", self.fraction); fraction_str.hash(state); diff --git a/src/daft-plan/src/logical_optimization/optimizer.rs b/src/daft-plan/src/logical_optimization/optimizer.rs index 535eb16448..a53d5980da 100644 --- a/src/daft-plan/src/logical_optimization/optimizer.rs +++ b/src/daft-plan/src/logical_optimization/optimizer.rs @@ -405,7 +405,7 @@ mod tests { // 3 + 2 + 1 = 6 assert_eq!(pass_count, 6); - let mut new_proj_exprs = proj_exprs.clone(); + let mut new_proj_exprs = proj_exprs; new_proj_exprs.rotate_left(2); let new_pred = filter_predicate .or(lit(false)) @@ -446,7 +446,7 @@ mod tests { }; let new_predicate = filter.predicate.or(lit(false)); Ok(Transformed::yes( - LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(), + LogicalPlan::from(Filter::try_new(filter.input, new_predicate)?).into(), )) }) } @@ -473,7 +473,7 @@ mod tests { }; let new_predicate = filter.predicate.and(lit(true)); Ok(Transformed::yes( - LogicalPlan::from(Filter::try_new(filter.input.clone(), new_predicate)?).into(), + LogicalPlan::from(Filter::try_new(filter.input, new_predicate)?).into(), )) }) } @@ -511,7 +511,7 @@ mod tests { exprs.rotate_left(1); } Ok(Transformed::yes( - LogicalPlan::from(Project::try_new(project.input.clone(), exprs)?).into(), + LogicalPlan::from(Project::try_new(project.input, exprs)?).into(), )) }) } diff --git a/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs b/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs index 727ebec298..112f89bd2d 100644 --- a/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs +++ b/src/daft-plan/src/logical_optimization/rules/drop_repartition.rs @@ -85,7 +85,7 @@ mod tests { .hash_repartition(Some(num_partitions2), partition_by.clone())? .build(); let expected = dummy_scan_node(scan_op) - .hash_repartition(Some(num_partitions2), partition_by.clone())? + .hash_repartition(Some(num_partitions2), partition_by)? .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs index 14728dab25..cffd5588ce 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs @@ -57,7 +57,7 @@ impl PushDownFilter { // Split predicate expression on conjunctions (ANDs). let parent_predicates = split_conjuction(&filter.predicate); - let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().cloned().collect(); + let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().copied().collect(); // Add child predicate expressions to parent predicate expressions, eliminating duplicates. 
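The `Sample` hunk above hashes an `f64` by first rendering it at fixed precision, since `f64` has no `Hash` impl (NaN and -0.0 make equality non-total). A minimal sketch of that idiom, with the struct trimmed down to just the fraction (the real node also hashes its input and other fields):

```
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Trimmed-down Sample: only the fraction, which is the interesting field here.
struct Sample {
    fraction: f64,
}

impl Hash for Sample {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Six decimal places gives a stable string token to hash.
        format!("{:.6}", self.fraction).hash(state);
    }
}

fn main() {
    let hash_of = |s: &Sample| {
        let mut h = DefaultHasher::new();
        s.hash(&mut h);
        h.finish()
    };
    // Values that agree to six decimal places hash identically.
    assert_eq!(
        hash_of(&Sample { fraction: 0.25 }),
        hash_of(&Sample { fraction: 0.250_000_000_1 })
    );
}
```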
let new_predicates: Vec = parent_predicates .iter() @@ -76,7 +76,6 @@ impl PushDownFilter { self.try_optimize_node(new_filter.clone())? .or(Transformed::yes(new_filter)) .data - .clone() } LogicalPlan::Source(source) => { match source.source_info.as_ref() { @@ -97,7 +96,7 @@ impl PushDownFilter { .filters .as_ref() .map(|f| predicate.clone().and(f.clone())) - .unwrap_or(predicate.clone()); + .unwrap_or_else(|| predicate.clone()); // We split the predicate into three groups: // 1. All partition-only filters, which can be applied directly to partition values and can be // dropped from the data-level filter. @@ -681,19 +680,13 @@ mod tests { let expected_left_filter_scan = if push_into_left_scan { dummy_scan_node_with_pushdowns( left_scan_op.clone(), - Pushdowns::default().with_filters(Some(pred.clone())), + Pushdowns::default().with_filters(Some(pred)), ) } else { - left_scan_plan.filter(pred.clone())? + left_scan_plan.filter(pred)? }; let expected = expected_left_filter_scan - .join( - &right_scan_plan, - join_on.clone(), - join_on.clone(), - how, - None, - )? + .join(&right_scan_plan, join_on.clone(), join_on, how, None)? .build(); assert_optimized_plan_eq(plan, expected)?; Ok(()) @@ -733,16 +726,16 @@ mod tests { let expected_right_filter_scan = if push_into_right_scan { dummy_scan_node_with_pushdowns( right_scan_op.clone(), - Pushdowns::default().with_filters(Some(pred.clone())), + Pushdowns::default().with_filters(Some(pred)), ) } else { - right_scan_plan.filter(pred.clone())? + right_scan_plan.filter(pred)? }; let expected = left_scan_plan .join( &expected_right_filter_scan, join_on.clone(), - join_on.clone(), + join_on, how, None, )? @@ -815,7 +808,7 @@ mod tests { .join( &expected_right_filter_scan, join_on.clone(), - join_on.clone(), + join_on, how, None, )? @@ -842,14 +835,8 @@ mod tests { let join_on = vec![col("b")]; let pred = col("a").lt(lit(2)); let plan = left_scan_plan - .join( - &right_scan_plan, - join_on.clone(), - join_on.clone(), - how, - None, - )? - .filter(pred.clone())? + .join(&right_scan_plan, join_on.clone(), join_on, how, None)? + .filter(pred)? .build(); // should not push down filter let expected = plan.clone(); @@ -875,14 +862,8 @@ mod tests { let join_on = vec![col("b")]; let pred = col("c").lt(lit(2.0)); let plan = left_scan_plan - .join( - &right_scan_plan, - join_on.clone(), - join_on.clone(), - how, - None, - )? - .filter(pred.clone())? + .join(&right_scan_plan, join_on.clone(), join_on, how, None)? + .filter(pred)? .build(); // should not push down filter let expected = plan.clone(); diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs b/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs index b8a3a223bd..66351c77d5 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_limit.rs @@ -106,8 +106,7 @@ impl PushDownLimit { let optimized = self .try_optimize_node(new_plan.clone())? 
.or(Transformed::yes(new_plan)) - .data - .clone(); + .data; Ok(Transformed::yes(optimized)) } _ => Ok(Transformed::no(plan)), diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index b063823b9b..399504050a 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -65,7 +65,7 @@ impl PushDownProjection { let upstream_computations = upstream_projection .projection .iter() - .flat_map(|e| { + .filter_map(|e| { e.input_mapping().map_or_else( // None means computation required -> Some(colname) || Some(e.name().to_string()), @@ -76,7 +76,7 @@ impl PushDownProjection { .collect::>(); // For each of them, make sure they are used only once in this downstream projection. - let mut exprs_to_walk: Vec> = projection.projection.to_vec(); + let mut exprs_to_walk: Vec> = projection.projection.clone(); let mut upstream_computations_used = IndexSet::new(); let mut okay_to_merge = true; @@ -91,8 +91,8 @@ impl PushDownProjection { && let Expr::Column(name) = expr.as_ref() && upstream_computations.contains(name.as_ref()) { - okay_to_merge = - okay_to_merge && upstream_computations_used.insert(name.to_string()) + okay_to_merge = okay_to_merge + && upstream_computations_used.insert(name.to_string()); }; if okay_to_merge { expr.children() @@ -130,7 +130,7 @@ impl PushDownProjection { // Root node is changed, look at it again. let new_plan = self .try_optimize_node(new_plan.clone())? - .or(Transformed::yes(new_plan.clone())); + .or(Transformed::yes(new_plan)); return Ok(new_plan); } } @@ -402,9 +402,8 @@ impl PushDownProjection { let new_left_subprojection: LogicalPlan = { Project::try_new(concat.input.clone(), pushdown_column_exprs.clone())?.into() }; - let new_right_subprojection: LogicalPlan = { - Project::try_new(concat.other.clone(), pushdown_column_exprs.clone())?.into() - }; + let new_right_subprojection: LogicalPlan = + { Project::try_new(concat.other.clone(), pushdown_column_exprs)?.into() }; let new_upstream = upstream_plan.with_new_children(&[ new_left_subprojection.into(), @@ -447,10 +446,8 @@ impl PushDownProjection { .collect(); if combined_dependencies.len() < upstream_names.len() { - let pushdown_column_exprs: Vec = combined_dependencies - .into_iter() - .map(|d| col(d.to_string())) - .collect(); + let pushdown_column_exprs: Vec = + combined_dependencies.into_iter().map(col).collect(); let new_project: LogicalPlan = Project::try_new(side.clone(), pushdown_column_exprs)?.into(); Ok(Transformed::yes(new_project.into())) @@ -474,10 +471,8 @@ impl PushDownProjection { Ok(Transformed::no(plan)) } else { // If either pushdown is possible, create a new Join node. - let new_join = upstream_plan.with_new_children(&[ - new_left_upstream.data.clone(), - new_right_upstream.data.clone(), - ]); + let new_join = upstream_plan + .with_new_children(&[new_left_upstream.data, new_right_upstream.data]); let new_plan = Arc::new(plan.with_new_children(&[new_join.into()])); @@ -696,7 +691,7 @@ mod tests { /// Projection merging: Ensure factored projections do not get merged. 
#[test] fn test_merge_does_not_unfactor() -> DaftResult<()> { - let a2 = col("a").clone().add(col("a")); + let a2 = col("a").add(col("a")); let a4 = a2.clone().add(a2); let a8 = a4.clone().add(a4); let expressions = vec![a8.alias("x")]; @@ -1001,10 +996,11 @@ mod tests { // Select the `udf_results` column, so the ActorPoolProject should apply column pruning to the other columns let plan = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( - scan_node.clone(), + scan_node, vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results_0")], )?) .arced(); + let plan = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( plan, vec![ @@ -1015,6 +1011,7 @@ mod tests { ], )?) .arced(); + let plan = LogicalPlan::Project(Project::try_new( plan, vec![ @@ -1035,7 +1032,7 @@ mod tests { )?) .arced(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( - expected.clone(), + expected, vec![ // Absorbed a non-computational expression (alias) from the Projection col("udf_results_0").alias("udf_results_0_alias"), @@ -1086,7 +1083,7 @@ mod tests { // Select only col("a"), so the ActorPoolProject node is now redundant and should be removed let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( - scan_node.clone(), + scan_node, vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")], )?) .arced(); diff --git a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs index 1d77502221..170eace0d6 100644 --- a/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs +++ b/src/daft-plan/src/logical_optimization/rules/split_actor_pool_projects.rs @@ -122,7 +122,7 @@ impl SplitActorPoolProjects { impl OptimizerRule for SplitActorPoolProjects { fn try_optimize(&self, plan: Arc) -> DaftResult>> { plan.transform_down(|node| match node.as_ref() { - LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone(), 0), + LogicalPlan::Project(projection) => try_optimize_project(projection, node.clone()), _ => Ok(Transformed::no(node)), }) } @@ -355,6 +355,10 @@ fn split_projection( } else { truncated_exprs.push(expr.clone()); for required_col_name in get_required_columns(expr) { + #[expect( + clippy::set_contains_or_insert, + reason = "we are arcing it later; we might want to use contains separately unless there is a better way" + )] if !new_children_seen.contains(&required_col_name) { let colexpr = Expr::Column(required_col_name.as_str().into()).arced(); new_children_seen.insert(required_col_name); @@ -370,8 +374,34 @@ fn split_projection( fn try_optimize_project( projection: &Project, plan: Arc, +) -> DaftResult>> { + // Add aliases to the expressions in the projection to preserve original names when splitting stateful UDFs. + // This is needed because when we split stateful UDFs, we create new names for intermediates, but we would like + // to have the same expression names as the original projection. 
+ let aliased_projection_exprs = projection + .projection + .iter() + .map(|e| { + if has_stateful_udf(e) && !matches!(e.as_ref(), Expr::Alias(..)) { + e.alias(e.name()) + } else { + e.clone() + } + }) + .collect(); + + let aliased_projection = Project::try_new(projection.input.clone(), aliased_projection_exprs)?; + + recursive_optimize_project(&aliased_projection, plan, 0) +} + +fn recursive_optimize_project( + projection: &Project, + plan: Arc, recursive_count: usize, ) -> DaftResult>> { + // TODO: eliminate the need for recursive calls by doing a post-order traversal of the plan tree. + // Base case: no stateful UDFs at all let has_stateful_udfs = projection.projection.iter().any(has_stateful_udf); if !has_stateful_udfs { @@ -417,8 +447,8 @@ fn try_optimize_project( let new_project = Project::try_new(projection.input.clone(), remaining)?; let new_child_project = LogicalPlan::Project(new_project.clone()).arced(); let optimized_child_plan = - try_optimize_project(&new_project, new_child_project.clone(), recursive_count + 1)?; - optimized_child_plan.data.clone() + recursive_optimize_project(&new_project, new_child_project, recursive_count + 1)?; + optimized_child_plan.data }; // Start building a chain of `child -> Project -> ActorPoolProject -> ActorPoolProject -> ... -> Project` @@ -448,11 +478,8 @@ fn try_optimize_project( .into_iter() .chain(stateless_stages) .collect(); - let new_plan = LogicalPlan::Project(Project::try_new( - new_plan_child.clone(), - stateless_projection, - )?) - .arced(); + let new_plan = + LogicalPlan::Project(Project::try_new(new_plan_child, stateless_projection)?).arced(); // Iteratively build ActorPoolProject nodes: [...all columns that came before it, StatefulUDF] let new_plan = { @@ -603,14 +630,14 @@ mod tests { // Add a Projection with StatefulUDF and resource request let project_plan = scan_plan - .with_columns(vec![stateful_project_expr.clone().alias("b")])? + .with_columns(vec![stateful_project_expr.alias("b")])? .build(); // Project([col("a")]) --> ActorPoolProject([col("a"), foo(col("a")).alias("b")]) --> Project([col("a"), col("b")]) let expected = scan_plan.select(vec![col("a")])?.build(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( expected, - vec![col("a"), stateful_project_expr.clone().alias("b")], + vec![col("a"), stateful_project_expr.alias("b")], )?) .arced(); let expected = @@ -719,7 +746,7 @@ mod tests { // Add a Projection with StatefulUDF and resource request // Project([col("a"), foo(foo(col("a"))).alias("b")]) let project_plan = scan_plan - .with_columns(vec![stacked_stateful_project_expr.clone().alias("b")])? + .with_columns(vec![stacked_stateful_project_expr.alias("b")])? .build(); let intermediate_name = "__TruncateRootStatefulUDF_0-1-0__"; @@ -728,9 +755,7 @@ mod tests { expected, vec![ col("a"), - create_stateful_udf(vec![col("a")]) - .clone() - .alias(intermediate_name), + create_stateful_udf(vec![col("a")]).alias(intermediate_name), ], )?) .arced(); @@ -749,24 +774,20 @@ mod tests { vec![ col(intermediate_name), col("a"), - create_stateful_udf(vec![col(intermediate_name)]) - .clone() - .alias("b"), + create_stateful_udf(vec![col(intermediate_name)]).alias("b"), ], )?) 
.arced(); let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a"), col("b")])?).arced(); - assert_optimized_plan_eq(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq(project_plan.clone(), expected)?; // With Projection Pushdown, elide intermediate Projects and also perform column pushdown let expected = scan_plan.build(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( expected, vec![ - create_stateful_udf(vec![col("a")]) - .clone() - .alias(intermediate_name), + create_stateful_udf(vec![col("a")]).alias(intermediate_name), col("a"), ], )?) @@ -775,9 +796,7 @@ mod tests { expected, vec![ col("a"), - create_stateful_udf(vec![col(intermediate_name)]) - .clone() - .alias("b"), + create_stateful_udf(vec![col(intermediate_name)]).alias("b"), ], )?) .arced(); @@ -785,6 +804,59 @@ mod tests { Ok(()) } + #[test] + fn test_multiple_with_column_serial_no_alias() -> DaftResult<()> { + let scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]); + let scan_plan = dummy_scan_node(scan_op); + let stacked_stateful_project_expr = + create_stateful_udf(vec![create_stateful_udf(vec![col("a")])]); + + // Add a Projection with StatefulUDF and resource request + let project_plan = scan_plan + .select(vec![stacked_stateful_project_expr])? + .build(); + + let intermediate_name = "__TruncateRootStatefulUDF_0-0-0__"; + + let expected = scan_plan.select(vec![col("a")])?.build(); + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + expected, + vec![ + col("a"), + create_stateful_udf(vec![col("a")]).alias(intermediate_name), + ], + )?) + .arced(); + let expected = + LogicalPlan::Project(Project::try_new(expected, vec![col(intermediate_name)])?).arced(); + let expected = + LogicalPlan::Project(Project::try_new(expected, vec![col(intermediate_name)])?).arced(); + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + expected, + vec![ + col(intermediate_name), + create_stateful_udf(vec![col(intermediate_name)]).alias("a"), + ], + )?) + .arced(); + let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a")])?).arced(); + assert_optimized_plan_eq(project_plan.clone(), expected)?; + + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + scan_plan.build(), + vec![create_stateful_udf(vec![col("a")]).alias(intermediate_name)], + )?) + .arced(); + let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( + expected, + vec![create_stateful_udf(vec![col(intermediate_name)]).alias("a")], + )?) + .arced(); + assert_optimized_plan_eq_with_projection_pushdown(project_plan, expected)?; + + Ok(()) + } + #[test] fn test_multiple_with_column_serial_multiarg() -> DaftResult<()> { let scan_op = dummy_scan_operator(vec![ @@ -800,7 +872,7 @@ mod tests { // Add a Projection with StatefulUDF and resource request // Project([foo(foo(col("a")), foo(col("b"))).alias("c")]) let project_plan = scan_plan - .select(vec![stacked_stateful_project_expr.clone().alias("c")])? + .select(vec![stacked_stateful_project_expr.alias("c")])? .build(); let intermediate_name_0 = "__TruncateRootStatefulUDF_0-0-0__"; @@ -811,9 +883,7 @@ mod tests { vec![ col("a"), col("b"), - create_stateful_udf(vec![col("a")]) - .clone() - .alias(intermediate_name_0), + create_stateful_udf(vec![col("a")]).alias(intermediate_name_0), ], )?) 
.arced(); @@ -823,9 +893,7 @@ mod tests { col("a"), col("b"), col(intermediate_name_0), - create_stateful_udf(vec![col("b")]) - .clone() - .alias(intermediate_name_1), + create_stateful_udf(vec![col("b")]).alias(intermediate_name_1), ], )?) .arced(); @@ -845,22 +913,19 @@ mod tests { col(intermediate_name_0), col(intermediate_name_1), create_stateful_udf(vec![col(intermediate_name_0), col(intermediate_name_1)]) - .clone() .alias("c"), ], )?) .arced(); let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("c")])?).arced(); - assert_optimized_plan_eq(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq(project_plan.clone(), expected)?; // With Projection Pushdown, elide intermediate Projects and also perform column pushdown let expected = scan_plan.build(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( expected, vec![ - create_stateful_udf(vec![col("a")]) - .clone() - .alias(intermediate_name_0), + create_stateful_udf(vec![col("a")]).alias(intermediate_name_0), col("b"), ], )?) @@ -869,9 +934,7 @@ mod tests { expected, vec![ col(intermediate_name_0), - create_stateful_udf(vec![col("b")]) - .clone() - .alias(intermediate_name_1), + create_stateful_udf(vec![col("b")]).alias(intermediate_name_1), ], )?) .arced(); @@ -879,12 +942,11 @@ mod tests { expected, vec![ create_stateful_udf(vec![col(intermediate_name_0), col(intermediate_name_1)]) - .clone() .alias("c"), ], )?) .arced(); - assert_optimized_plan_eq_with_projection_pushdown(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq_with_projection_pushdown(project_plan, expected)?; Ok(()) } @@ -903,7 +965,7 @@ mod tests { // Add a Projection with StatefulUDF and resource request // Project([foo(foo(col("a")) + foo(col("b"))).alias("c")]) let project_plan = scan_plan - .select(vec![stacked_stateful_project_expr.clone().alias("c")])? + .select(vec![stacked_stateful_project_expr.alias("c")])? .build(); let intermediate_name_0 = "__TruncateAnyStatefulUDFChildren_1-0-0__"; @@ -915,9 +977,7 @@ mod tests { vec![ col("a"), col("b"), - create_stateful_udf(vec![col("a")]) - .clone() - .alias(intermediate_name_0), + create_stateful_udf(vec![col("a")]).alias(intermediate_name_0), ], )?) .arced(); @@ -927,9 +987,7 @@ mod tests { col("a"), col("b"), col(intermediate_name_0), - create_stateful_udf(vec![col("b")]) - .clone() - .alias(intermediate_name_1), + create_stateful_udf(vec![col("b")]).alias(intermediate_name_1), ], )?) .arced(); @@ -959,23 +1017,19 @@ mod tests { expected, vec![ col(intermediate_name_2), - create_stateful_udf(vec![col(intermediate_name_2)]) - .clone() - .alias("c"), + create_stateful_udf(vec![col(intermediate_name_2)]).alias("c"), ], )?) .arced(); let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("c")])?).arced(); - assert_optimized_plan_eq(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq(project_plan.clone(), expected)?; // With Projection Pushdown, elide intermediate Projects and also perform column pushdown let expected = scan_plan.build(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( expected, vec![ - create_stateful_udf(vec![col("a")]) - .clone() - .alias(intermediate_name_0), + create_stateful_udf(vec![col("a")]).alias(intermediate_name_0), col("b"), ], )?) @@ -984,9 +1038,7 @@ mod tests { expected, vec![ col(intermediate_name_0), - create_stateful_udf(vec![col("b")]) - .clone() - .alias(intermediate_name_1), + create_stateful_udf(vec![col("b")]).alias(intermediate_name_1), ], )?) 
.arced(); @@ -999,12 +1051,10 @@ mod tests { .arced(); let expected = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new( expected, - vec![create_stateful_udf(vec![col(intermediate_name_2)]) - .clone() - .alias("c")], + vec![create_stateful_udf(vec![col(intermediate_name_2)]).alias("c")], )?) .arced(); - assert_optimized_plan_eq_with_projection_pushdown(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq_with_projection_pushdown(project_plan, expected)?; Ok(()) } @@ -1018,10 +1068,7 @@ mod tests { // Add a Projection with StatefulUDF and resource request // Project([foo(col("a") + foo(col("a"))).alias("c")]) let project_plan = scan_plan - .select(vec![ - col("a"), - stacked_stateful_project_expr.clone().alias("c"), - ])? + .select(vec![col("a"), stacked_stateful_project_expr.alias("c")])? .build(); let intermediate_name_0 = "__TruncateAnyStatefulUDFChildren_1-1-0__"; @@ -1074,7 +1121,7 @@ mod tests { let expected = LogicalPlan::Project(Project::try_new(expected, vec![col("a"), col("c")])?).arced(); - assert_optimized_plan_eq(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq(project_plan, expected)?; Ok(()) } @@ -1126,7 +1173,7 @@ mod tests { LogicalPlan::Project(Project::try_new(expected, vec![col("a"), col("result")])?) .arced(); - assert_optimized_plan_eq(project_plan.clone(), expected.clone())?; + assert_optimized_plan_eq(project_plan, expected)?; Ok(()) } diff --git a/src/daft-plan/src/logical_optimization/test/mod.rs b/src/daft-plan/src/logical_optimization/test/mod.rs index 75b53b2182..a2b16d6188 100644 --- a/src/daft-plan/src/logical_optimization/test/mod.rs +++ b/src/daft-plan/src/logical_optimization/test/mod.rs @@ -25,8 +25,7 @@ pub fn assert_optimized_plan_with_rules_eq( ); let optimized_plan = optimizer .optimize_with_rules(optimizer.rule_batches[0].rules.as_slice(), plan.clone())? 
- .data - .clone(); + .data; assert_eq!( optimized_plan, expected, diff --git a/src/daft-plan/src/physical_ops/empty_scan.rs b/src/daft-plan/src/physical_ops/empty_scan.rs index 63097b33b9..d18196bf21 100644 --- a/src/daft-plan/src/physical_ops/empty_scan.rs +++ b/src/daft-plan/src/physical_ops/empty_scan.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::ClusteringSpec; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct EmptyScan { pub schema: SchemaRef, pub clustering_spec: Arc, diff --git a/src/daft-plan/src/physical_ops/iceberg_write.rs b/src/daft-plan/src/physical_ops/iceberg_write.rs index c5959055c2..9036b77aef 100644 --- a/src/daft-plan/src/physical_ops/iceberg_write.rs +++ b/src/daft-plan/src/physical_ops/iceberg_write.rs @@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize}; use crate::{physical_plan::PhysicalPlanRef, sink_info::IcebergCatalogInfo}; #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] - pub struct IcebergWrite { pub schema: SchemaRef, pub iceberg_info: IcebergCatalogInfo, diff --git a/src/daft-plan/src/physical_ops/in_memory.rs b/src/daft-plan/src/physical_ops/in_memory.rs index 56f52533c4..1a936daa22 100644 --- a/src/daft-plan/src/physical_ops/in_memory.rs +++ b/src/daft-plan/src/physical_ops/in_memory.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::{source_info::InMemoryInfo, ClusteringSpec}; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct InMemoryScan { pub schema: SchemaRef, pub in_memory_info: InMemoryInfo, diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index 8a9a79a658..9ca3e0def1 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -67,6 +67,7 @@ pub use unpivot::Unpivot; #[macro_export] /// Implement the `common_display::tree::TreeDisplay` trait for the given struct +/// /// using the `get_name` method as the compact description and the `multiline_display` method for the default and verbose descriptions. macro_rules! 
impl_default_tree_display { ($($struct:ident),+) => { diff --git a/src/daft-plan/src/physical_ops/scan.rs b/src/daft-plan/src/physical_ops/scan.rs index d99fc9297b..10227745b9 100644 --- a/src/daft-plan/src/physical_ops/scan.rs +++ b/src/daft-plan/src/physical_ops/scan.rs @@ -51,8 +51,7 @@ Num Scan Tasks = {num_scan_tasks} Estimated Scan Bytes = {total_bytes} Clustering spec = {{ {clustering_spec} }} " - ) - .to_string(); + ); #[cfg(feature = "python")] if let FileFormatConfig::Database(config) = scan.scan_tasks[0].file_format_config.as_ref() diff --git a/src/daft-plan/src/physical_optimization/optimizer.rs b/src/daft-plan/src/physical_optimization/optimizer.rs index 58e2d43e53..80cd2863b8 100644 --- a/src/daft-plan/src/physical_optimization/optimizer.rs +++ b/src/daft-plan/src/physical_optimization/optimizer.rs @@ -44,7 +44,7 @@ impl PhysicalOptimizer { } pub fn optimize(&self, mut plan: PhysicalPlanRef) -> DaftResult { - for batch in self.rule_batches.iter() { + for batch in &self.rule_batches { plan = batch.optimize(plan, &self.config)?; } Ok(plan) diff --git a/src/daft-plan/src/physical_optimization/plan_context.rs b/src/daft-plan/src/physical_optimization/plan_context.rs index 95f3126c08..7a0976a9ca 100644 --- a/src/daft-plan/src/physical_optimization/plan_context.rs +++ b/src/daft-plan/src/physical_optimization/plan_context.rs @@ -42,7 +42,7 @@ impl PlanContext { impl PlanContext { // Clone the context to the children pub fn propagate(mut self) -> Self { - for child in self.children.iter_mut() { + for child in &mut self.children { child.context = self.context.clone(); } self diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 615d656b92..34304719dc 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -128,7 +128,7 @@ impl PhysicalPlan { }) => clustering_spec.clone(), Self::Sample(Sample { input, .. }) => input.clustering_spec(), Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => { - input.clustering_spec().clone() + input.clustering_spec() } Self::Sort(Sort { @@ -253,7 +253,7 @@ impl PhysicalPlan { }, Self::TabularScan(TabularScan { scan_tasks, .. 
}) => { let mut stats = ApproxStats::empty(); - for st in scan_tasks.iter() { + for st in scan_tasks { stats.lower_bound_rows += st.num_rows().unwrap_or(0); let in_memory_size = st.estimate_in_memory_size_bytes(None); stats.lower_bound_bytes += in_memory_size.unwrap_or(0); diff --git a/src/daft-plan/src/physical_planner/mod.rs b/src/daft-plan/src/physical_planner/mod.rs index 9813a73be8..3fe1d6212d 100644 --- a/src/daft-plan/src/physical_planner/mod.rs +++ b/src/daft-plan/src/physical_planner/mod.rs @@ -20,7 +20,7 @@ pub fn logical_to_physical( ) -> DaftResult { let mut visitor = PhysicalPlanTranslator { physical_children: vec![], - cfg: cfg.clone(), + cfg, }; let _output = logical_plan.visit(&mut visitor)?; assert_eq!( diff --git a/src/daft-plan/src/physical_planner/planner.rs b/src/daft-plan/src/physical_planner/planner.rs index 5071c1bce2..837086448c 100644 --- a/src/daft-plan/src/physical_planner/planner.rs +++ b/src/daft-plan/src/physical_planner/planner.rs @@ -193,7 +193,7 @@ impl TreeNodeRewriter for QueryStagePhysicalPlanTranslator { _ => panic!("We shouldn't have any nodes that have more than 3 children"), } } else { - self.physical_children.push(translated_pplan.clone()); + self.physical_children.push(translated_pplan); Ok(Transformed::no(node)) } } diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 639c571871..c7a364c770 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -8,7 +8,10 @@ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_file_formats::FileFormat; use daft_core::prelude::*; -use daft_dsl::{col, is_partition_compatible, ApproxPercentileParams, ExprRef, SketchType}; +use daft_dsl::{ + col, is_partition_compatible, AggExpr, ApproxPercentileParams, ExprRef, SketchType, +}; +use daft_functions::numeric::sqrt; use daft_scan::PhysicalScanInfo; use crate::{ @@ -216,7 +219,7 @@ pub(super) fn translate_single_logical_node( let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( input_physical, num_partitions, - by.clone(), + by, )); PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())) } @@ -449,9 +452,9 @@ pub(super) fn translate_single_logical_node( left_clustering_spec.as_ref() { by.len() >= left_on.len() - && by.iter().zip(left_on.iter()).all(|(e1, e2)| e1 == e2) - // TODO(Clark): Add support for descending sort orders. - && descending.iter().all(|v| !*v) + && by.iter().zip(left_on.iter()).all(|(e1, e2)| e1 == e2) + // TODO(Clark): Add support for descending sort orders. + && descending.iter().all(|v| !*v) } else { false }; @@ -462,9 +465,9 @@ pub(super) fn translate_single_logical_node( right_clustering_spec.as_ref() { by.len() >= right_on.len() - && by.iter().zip(right_on.iter()).all(|(e1, e2)| e1 == e2) - // TODO(Clark): Add support for descending sort orders. - && descending.iter().all(|v| !*v) + && by.iter().zip(right_on.iter()).all(|(e1, e2)| e1 == e2) + // TODO(Clark): Add support for descending sort orders. 
+ && descending.iter().all(|v| !*v) } else { false }; @@ -587,7 +590,7 @@ pub(super) fn translate_single_logical_node( std::iter::repeat(false).take(left_on.len()).collect(), num_partitions, )) - .arced() + .arced(); } if !is_right_sort_partitioned { right_physical = PhysicalPlan::Sort(Sort::new( @@ -596,7 +599,7 @@ pub(super) fn translate_single_logical_node( std::iter::repeat(false).take(right_on.len()).collect(), num_partitions, )) - .arced() + .arced(); } false }; @@ -765,8 +768,6 @@ pub fn populate_aggregation_stages( HashMap<Arc<str>, daft_dsl::AggExpr>, Vec<ExprRef>, ) { - use daft_dsl::AggExpr::{self, *}; - // Aggregations to apply in the first and second stages. // Semantic column name -> AggExpr let mut first_stage_aggs: HashMap<Arc<str>, AggExpr> = HashMap::new(); @@ -774,147 +775,245 @@ // Project the aggregation results to their final output names let mut final_exprs: Vec<ExprRef> = group_by.iter().map(|e| col(e.name())).collect(); + fn add_to_stage<F>( + f: F, + expr: ExprRef, + schema: &Schema, + stage: &mut HashMap<Arc<str>, AggExpr>, + ) -> Arc<str> + where + F: Fn(ExprRef) -> AggExpr, + { + let id = f(expr.clone()).semantic_id(schema).id; + let agg_expr = f(expr.alias(id.clone())); + stage.insert(id.clone(), agg_expr); + id + } + for agg_expr in aggregations { let output_name = agg_expr.name(); match agg_expr { - Count(e, mode) => { + AggExpr::Count(e, mode) => { let count_id = agg_expr.semantic_id(schema).id; - let sum_of_count_id = Sum(col(count_id.clone())).semantic_id(schema).id; + let sum_of_count_id = AggExpr::Sum(col(count_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(count_id.clone()) - .or_insert(Count(e.alias(count_id.clone()).clone(), *mode)); + .or_insert(AggExpr::Count(e.alias(count_id.clone()).clone(), *mode)); second_stage_aggs .entry(sum_of_count_id.clone()) - .or_insert(Sum(col(count_id.clone()).alias(sum_of_count_id.clone()))); + .or_insert(AggExpr::Sum( + col(count_id.clone()).alias(sum_of_count_id.clone()), + )); final_exprs.push(col(sum_of_count_id.clone()).alias(output_name)); } - Sum(e) => { + AggExpr::Sum(e) => { let sum_id = agg_expr.semantic_id(schema).id; - let sum_of_sum_id = Sum(col(sum_id.clone())).semantic_id(schema).id; + let sum_of_sum_id = AggExpr::Sum(col(sum_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(sum_id.clone()) - .or_insert(Sum(e.alias(sum_id.clone()).clone())); + .or_insert(AggExpr::Sum(e.alias(sum_id.clone()).clone())); second_stage_aggs .entry(sum_of_sum_id.clone()) - .or_insert(Sum(col(sum_id.clone()).alias(sum_of_sum_id.clone()))); + .or_insert(AggExpr::Sum( + col(sum_id.clone()).alias(sum_of_sum_id.clone()), + )); final_exprs.push(col(sum_of_sum_id.clone()).alias(output_name)); } - Mean(e) => { - let sum_id = Sum(e.clone()).semantic_id(schema).id; - let count_id = Count(e.clone(), CountMode::Valid).semantic_id(schema).id; - let sum_of_sum_id = Sum(col(sum_id.clone())).semantic_id(schema).id; - let sum_of_count_id = Sum(col(count_id.clone())).semantic_id(schema).id; + AggExpr::Mean(e) => { + let sum_id = AggExpr::Sum(e.clone()).semantic_id(schema).id; + let count_id = AggExpr::Count(e.clone(), CountMode::Valid) + .semantic_id(schema) + .id; + let sum_of_sum_id = AggExpr::Sum(col(sum_id.clone())).semantic_id(schema).id; + let sum_of_count_id = AggExpr::Sum(col(count_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(sum_id.clone()) - .or_insert(Sum(e.alias(sum_id.clone()).clone())); + .or_insert(AggExpr::Sum(e.alias(sum_id.clone()).clone())); first_stage_aggs .entry(count_id.clone()) -
.or_insert(Count(e.alias(count_id.clone()).clone(), CountMode::Valid)); + .or_insert(AggExpr::Count( + e.alias(count_id.clone()).clone(), + CountMode::Valid, + )); second_stage_aggs .entry(sum_of_sum_id.clone()) - .or_insert(Sum(col(sum_id.clone()).alias(sum_of_sum_id.clone()))); + .or_insert(AggExpr::Sum( + col(sum_id.clone()).alias(sum_of_sum_id.clone()), + )); second_stage_aggs .entry(sum_of_count_id.clone()) - .or_insert(Sum(col(count_id.clone()).alias(sum_of_count_id.clone()))); + .or_insert(AggExpr::Sum( + col(count_id.clone()).alias(sum_of_count_id.clone()), + )); final_exprs.push( (col(sum_of_sum_id.clone()).div(col(sum_of_count_id.clone()))) .alias(output_name), ); } - Min(e) => { + AggExpr::Stddev(sub_expr) => { + // The stddev calculation we're performing here is: + // stddev(X) = sqrt(E(X^2) - E(X)^2) + // where X is the sub_expr. + // + // First stage, we compute `sum(X^2)`, `sum(X)` and `count(X)`. + // Second stage, we `global_sqsum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage. + // In the final projection, we then compute `sqrt((global_sqsum / global_count) - (global_sum / global_count) ^ 2)`. + + // first stage aggregation + let sum_id = add_to_stage( + AggExpr::Sum, + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + let sq_sum_id = add_to_stage( + |sub_expr| AggExpr::Sum(sub_expr.clone().mul(sub_expr)), + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + let count_id = add_to_stage( + |sub_expr| AggExpr::Count(sub_expr, CountMode::Valid), + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + + // second stage aggregation + let global_sum_id = add_to_stage( + AggExpr::Sum, + col(sum_id.clone()), + schema, + &mut second_stage_aggs, + ); + let global_sq_sum_id = add_to_stage( + AggExpr::Sum, + col(sq_sum_id.clone()), + schema, + &mut second_stage_aggs, + ); + let global_count_id = add_to_stage( + AggExpr::Sum, + col(count_id.clone()), + schema, + &mut second_stage_aggs, + ); + + // final projection + let g_sq_sum = col(global_sq_sum_id); + let g_sum = col(global_sum_id); + let g_count = col(global_count_id); + let left = g_sq_sum.div(g_count.clone()); + let right = g_sum.div(g_count); + let right = right.clone().mul(right); + let result = sqrt::sqrt(left.sub(right)).alias(output_name); + + final_exprs.push(result); + } + AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; - let min_of_min_id = Min(col(min_id.clone())).semantic_id(schema).id; + let min_of_min_id = AggExpr::Min(col(min_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(min_id.clone()) - .or_insert(Min(e.alias(min_id.clone()).clone())); + .or_insert(AggExpr::Min(e.alias(min_id.clone()).clone())); second_stage_aggs .entry(min_of_min_id.clone()) - .or_insert(Min(col(min_id.clone()).alias(min_of_min_id.clone()))); + .or_insert(AggExpr::Min( + col(min_id.clone()).alias(min_of_min_id.clone()), + )); final_exprs.push(col(min_of_min_id.clone()).alias(output_name)); } - Max(e) => { + AggExpr::Max(e) => { let max_id = agg_expr.semantic_id(schema).id; - let max_of_max_id = Max(col(max_id.clone())).semantic_id(schema).id; + let max_of_max_id = AggExpr::Max(col(max_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(max_id.clone()) - .or_insert(Max(e.alias(max_id.clone()).clone())); + .or_insert(AggExpr::Max(e.alias(max_id.clone()).clone())); second_stage_aggs .entry(max_of_max_id.clone()) - .or_insert(Max(col(max_id.clone()).alias(max_of_max_id.clone()))); + 
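The `Stddev` arm above works because of the identity stddev(X) = sqrt(E(X^2) - E(X)^2): each partition only has to report `sum(X)`, `sum(X^2)` and `count(X)`, and the final projection recombines the globally summed versions. A minimal standalone check of that algebra (plain Rust, no Daft types; the two-partition split is illustrative):

```rust
fn main() {
    // Pretend the column lives in two partitions.
    let partitions: Vec<Vec<f64>> = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]];

    // First stage: per-partition sum(X), sum(X^2) and count(X).
    let firsts: Vec<(f64, f64, f64)> = partitions
        .iter()
        .map(|p| {
            let sum: f64 = p.iter().sum();
            let sq_sum: f64 = p.iter().map(|x| x * x).sum();
            (sum, sq_sum, p.len() as f64)
        })
        .collect();

    // Second stage: global_sum, global_sqsum and global_count are plain sums
    // of the first-stage results.
    let (g_sum, g_sq_sum, g_count) = firsts
        .iter()
        .fold((0.0, 0.0, 0.0), |(s, q, c), &(ps, pq, pc)| {
            (s + ps, q + pq, c + pc)
        });

    // Final projection: sqrt((global_sqsum / global_count) - (global_sum / global_count)^2).
    let mean = g_sum / g_count;
    let staged = (g_sq_sum / g_count - mean * mean).sqrt();

    // Single-pass population stddev over the whole column, for comparison.
    let all: Vec<f64> = partitions.concat();
    let m = all.iter().sum::<f64>() / all.len() as f64;
    let direct =
        (all.iter().map(|x| (x - m) * (x - m)).sum::<f64>() / all.len() as f64).sqrt();

    assert!((staged - direct).abs() < 1e-12);
    println!("staged = {staged}, direct = {direct}"); // both sqrt(2)
}
```

This computes the population form of the standard deviation; whether an engine exposes population or sample semantics is a separate choice from the decomposition itself.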
.or_insert(AggExpr::Max( + col(max_id.clone()).alias(max_of_max_id.clone()), + )); final_exprs.push(col(max_of_max_id.clone()).alias(output_name)); } - AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(e, ignore_nulls) => { let any_id = agg_expr.semantic_id(schema).id; - let any_of_any_id = AnyValue(col(any_id.clone()), *ignore_nulls) + let any_of_any_id = AggExpr::AnyValue(col(any_id.clone()), *ignore_nulls) .semantic_id(schema) .id; first_stage_aggs .entry(any_id.clone()) - .or_insert(AnyValue(e.alias(any_id.clone()).clone(), *ignore_nulls)); + .or_insert(AggExpr::AnyValue( + e.alias(any_id.clone()).clone(), + *ignore_nulls, + )); second_stage_aggs .entry(any_of_any_id.clone()) - .or_insert(AnyValue( + .or_insert(AggExpr::AnyValue( col(any_id.clone()).alias(any_of_any_id.clone()), *ignore_nulls, )); final_exprs.push(col(any_of_any_id.clone()).alias(output_name)); } - List(e) => { + AggExpr::List(e) => { let list_id = agg_expr.semantic_id(schema).id; - let concat_of_list_id = Concat(col(list_id.clone())).semantic_id(schema).id; + let concat_of_list_id = + AggExpr::Concat(col(list_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(list_id.clone()) - .or_insert(List(e.alias(list_id.clone()).clone())); + .or_insert(AggExpr::List(e.alias(list_id.clone()).clone())); second_stage_aggs .entry(concat_of_list_id.clone()) - .or_insert(Concat( + .or_insert(AggExpr::Concat( col(list_id.clone()).alias(concat_of_list_id.clone()), )); final_exprs.push(col(concat_of_list_id.clone()).alias(output_name)); } - Concat(e) => { + AggExpr::Concat(e) => { let concat_id = agg_expr.semantic_id(schema).id; - let concat_of_concat_id = Concat(col(concat_id.clone())).semantic_id(schema).id; + let concat_of_concat_id = AggExpr::Concat(col(concat_id.clone())) + .semantic_id(schema) + .id; first_stage_aggs .entry(concat_id.clone()) - .or_insert(Concat(e.alias(concat_id.clone()).clone())); + .or_insert(AggExpr::Concat(e.alias(concat_id.clone()).clone())); second_stage_aggs .entry(concat_of_concat_id.clone()) - .or_insert(Concat( + .or_insert(AggExpr::Concat( col(concat_id.clone()).alias(concat_of_concat_id.clone()), )); final_exprs.push(col(concat_of_concat_id.clone()).alias(output_name)); } - MapGroups { func, inputs } => { + AggExpr::MapGroups { func, inputs } => { let func_id = agg_expr.semantic_id(schema).id; // No first stage aggregation for MapGroups, do all the work in the second stage. 
second_stage_aggs .entry(func_id.clone()) - .or_insert(MapGroups { + .or_insert(AggExpr::MapGroups { func: func.clone(), - inputs: inputs.to_vec(), + inputs: inputs.clone(), }); final_exprs.push(col(output_name)); } - &ApproxPercentile(ApproxPercentileParams { + &AggExpr::ApproxPercentile(ApproxPercentileParams { child: ref e, ref percentiles, force_list_output, }) => { let percentiles = percentiles.iter().map(|p| p.0).collect::>(); let sketch_id = agg_expr.semantic_id(schema).id; - let approx_id = ApproxSketch(col(sketch_id.clone()), SketchType::DDSketch) + let approx_id = AggExpr::ApproxSketch(col(sketch_id.clone()), SketchType::DDSketch) .semantic_id(schema) .id; first_stage_aggs .entry(sketch_id.clone()) - .or_insert(ApproxSketch( + .or_insert(AggExpr::ApproxSketch( e.alias(sketch_id.clone()), SketchType::DDSketch, )); second_stage_aggs .entry(approx_id.clone()) - .or_insert(MergeSketch( + .or_insert(AggExpr::MergeSketch( col(sketch_id.clone()).alias(approx_id.clone()), SketchType::DDSketch, )); @@ -924,30 +1023,30 @@ pub fn populate_aggregation_stages( .alias(output_name), ); } - ApproxCountDistinct(e) => { + AggExpr::ApproxCountDistinct(e) => { let first_stage_id = agg_expr.semantic_id(schema).id; let second_stage_id = - MergeSketch(col(first_stage_id.clone()), SketchType::HyperLogLog) + AggExpr::MergeSketch(col(first_stage_id.clone()), SketchType::HyperLogLog) .semantic_id(schema) .id; first_stage_aggs .entry(first_stage_id.clone()) - .or_insert(ApproxSketch( + .or_insert(AggExpr::ApproxSketch( e.alias(first_stage_id.clone()), SketchType::HyperLogLog, )); second_stage_aggs .entry(second_stage_id.clone()) - .or_insert(MergeSketch( + .or_insert(AggExpr::MergeSketch( col(first_stage_id).alias(second_stage_id.clone()), SketchType::HyperLogLog, )); final_exprs.push(col(second_stage_id).alias(output_name)); } - ApproxSketch(..) => { + AggExpr::ApproxSketch(..) => { unimplemented!("User-facing approx_sketch aggregation is not implemented") } - MergeSketch(..) => { + AggExpr::MergeSketch(..) => { unimplemented!("User-facing merge_sketch aggregation is not implemented") } } @@ -992,7 +1091,7 @@ mod tests { 10 ); let logical_plan = builder.into_partitions(10)?.build(); - let physical_plan = logical_to_physical(logical_plan, cfg.clone())?; + let physical_plan = logical_to_physical(logical_plan, cfg)?; // Check that the last repartition was dropped (the last op should be the filter). 
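The `ApproxPercentile` and `ApproxCountDistinct` arms above follow the same two-stage shape, except the partial state is a mergeable sketch instead of a scalar: the first stage builds one sketch per partition and the second stage merges them. A hedged sketch of that flow using the `sketches_ddsketch` crate that the daft-sketch tests elsewhere in this diff exercise (the `merge` and `quantile` signatures are assumed from that crate's public API, not shown in this diff):

```rust
use sketches_ddsketch::{Config, DDSketch};

fn main() {
    let partitions = [vec![1.0_f64, 2.0, 3.0], vec![4.0, 5.0, 6.0]];

    // First stage (ApproxSketch): one sketch per partition.
    let sketches: Vec<DDSketch> = partitions
        .iter()
        .map(|p| {
            let mut s = DDSketch::new(Config::default());
            p.iter().for_each(|&v| s.add(v));
            s
        })
        .collect();

    // Second stage (MergeSketch): fold the partition sketches together.
    let mut merged = DDSketch::new(Config::default());
    for s in &sketches {
        merged.merge(s).expect("sketch configs must match");
    }

    // Final projection: read the percentile off the merged sketch.
    let p50 = merged.quantile(0.5).expect("quantile must be in [0, 1]");
    println!("approx p50 = {p50:?}"); // roughly 3.5 for 1..=6
}
```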
assert_matches!(physical_plan.as_ref(), PhysicalPlan::Filter(_)); Ok(()) @@ -1016,7 +1115,7 @@ mod tests { 1 ); let logical_plan = builder.hash_repartition(Some(1), vec![col("a")])?.build(); - let physical_plan = logical_to_physical(logical_plan, cfg.clone())?; + let physical_plan = logical_to_physical(logical_plan, cfg)?; assert_matches!(physical_plan.as_ref(), PhysicalPlan::TabularScan(_)); Ok(()) } @@ -1172,22 +1271,28 @@ mod tests { for mult in [1, 10] { let plan = get_hash_join_plan(cfg.clone(), l_opts.scale_by(mult), r_opts.scale_by(mult))?; - if !check_physical_matches(plan, l_exp, r_exp) { - panic!( - "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", - l_opts, r_opts, l_exp, r_exp, mult - ); - } + assert!( + check_physical_matches(plan, l_exp, r_exp), + "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", + l_opts, + r_opts, + l_exp, + r_exp, + mult + ); // reversed direction let plan = get_hash_join_plan(cfg.clone(), r_opts.scale_by(mult), l_opts.scale_by(mult))?; - if !check_physical_matches(plan, r_exp, l_exp) { - panic!( - "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", - r_opts, l_opts, r_exp, l_exp, mult - ); - } + assert!( + check_physical_matches(plan, r_exp, l_exp), + "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}", + r_opts, + l_opts, + r_exp, + l_exp, + mult + ); } } Ok(()) @@ -1215,7 +1320,7 @@ mod tests { assert!(check_physical_matches(physical_plan, false, true)); let physical_plan = get_hash_join_plan( - cfg.clone(), + cfg, RepartitionOptions::Good(20), RepartitionOptions::Bad(26), )?; @@ -1237,21 +1342,25 @@ mod tests { let cfg: Arc = DaftExecutionConfig::default().into(); for (l_opts, r_opts, l_exp, r_exp) in cases { let plan = get_hash_join_plan(cfg.clone(), l_opts, r_opts)?; - if !check_physical_matches(plan, l_exp, r_exp) { - panic!( - "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", - l_opts, r_opts, l_exp, r_exp - ); - } + assert!( + check_physical_matches(plan, l_exp, r_exp), + "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", + l_opts, + r_opts, + l_exp, + r_exp + ); // reversed direction let plan = get_hash_join_plan(cfg.clone(), r_opts, l_opts)?; - if !check_physical_matches(plan, r_exp, l_exp) { - panic!( - "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", - r_opts, l_opts, r_exp, l_exp - ); - } + assert!( + check_physical_matches(plan, r_exp, l_exp), + "Failed single partition hash join test on case ({:?}, {:?}, {}, {})", + r_opts, + l_opts, + r_exp, + l_exp + ); } Ok(()) } diff --git a/src/daft-plan/src/source_info/file_info.rs b/src/daft-plan/src/source_info/file_info.rs index 555e544c00..5517ebe178 100644 --- a/src/daft-plan/src/source_info/file_info.rs +++ b/src/daft-plan/src/source_info/file_info.rs @@ -132,7 +132,7 @@ impl FileInfos { .downcast_ref::() .unwrap() .iter() - .map(|n| n.cloned()) + .map(|n| n.copied()) .collect::>(); let num_rows = table .get_column("num_rows")? 
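The `.cloned()` to `.copied()` switches in `FileInfos` above are the change clippy's `cloned_instead_of_copied` lint suggests: for `Copy` element types, `copied()` documents that materializing the value is a plain bitwise copy rather than a potentially expensive clone. In isolation:

```rust
fn main() {
    // An arrow2 primitive-array iterator yields Option<&i64>; mirror that here.
    let sizes: Vec<Option<i64>> = vec![Some(1024), None, Some(2048)];
    let iter = sizes.iter().map(Option::as_ref); // Iterator<Item = Option<&i64>>

    // `copied()` on Option<&T> requires only T: Copy, so it can never hide a
    // deep clone, which is why it is the lint-clean spelling for primitives.
    let materialized: Vec<Option<i64>> = iter.map(|n| n.copied()).collect();
    assert_eq!(materialized, sizes);
}
```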
@@ -142,7 +142,7 @@ impl FileInfos { .downcast_ref::() .unwrap() .iter() - .map(|n| n.cloned()) + .map(|n| n.copied()) .collect::>(); Ok(Self::new_internal(file_paths, file_sizes, num_rows)) } diff --git a/src/daft-plan/src/treenode.rs b/src/daft-plan/src/treenode.rs index 7c4e42a6da..a06237c6a9 100644 --- a/src/daft-plan/src/treenode.rs +++ b/src/daft-plan/src/treenode.rs @@ -25,7 +25,7 @@ impl DynTreeNode for LogicalPlan { { Ok(self.with_new_children(&children).arced()) } else { - Ok(self.clone()) + Ok(self) } } } @@ -50,7 +50,7 @@ impl DynTreeNode for PhysicalPlan { { Ok(self.with_new_children(&children).arced()) } else { - Ok(self.clone()) + Ok(self) } } } diff --git a/src/daft-scan/src/anonymous.rs b/src/daft-scan/src/anonymous.rs index 956ee1c639..f6ed86b5e0 100644 --- a/src/daft-scan/src/anonymous.rs +++ b/src/daft-scan/src/anonymous.rs @@ -17,6 +17,7 @@ pub struct AnonymousScanOperator { } impl AnonymousScanOperator { + #[must_use] pub fn new( files: Vec, schema: SchemaRef, @@ -87,7 +88,7 @@ impl ScanOperator for AnonymousScanOperator { let chunk_spec = rg.map(ChunkSpec::Parquet); Ok(ScanTask::new( vec![DataSource::File { - path: f.to_string(), + path: f, chunk_spec, size_bytes: None, iceberg_delete_files: None, diff --git a/src/daft-scan/src/expr_rewriter.rs b/src/daft-scan/src/expr_rewriter.rs index 25f5a9e6a2..f678ad07c1 100644 --- a/src/daft-scan/src/expr_rewriter.rs +++ b/src/daft-scan/src/expr_rewriter.rs @@ -24,7 +24,9 @@ fn unalias(expr: ExprRef) -> DaftResult { } fn apply_partitioning_expr(expr: ExprRef, pfield: &PartitionField) -> Option { - use PartitionTransform::*; + use PartitionTransform::{ + Day, Hour, IcebergBucket, IcebergTruncate, Identity, Month, Void, Year, + }; match pfield.transform { Some(Identity) => Some( pfield @@ -65,6 +67,7 @@ pub struct PredicateGroups { } impl PredicateGroups { + #[must_use] pub fn new( partition_only_filter: Vec, data_only_filter: Vec, @@ -96,7 +99,7 @@ pub fn rewrite_predicate_for_partitioning( // Predicates that only reference data columns (no partition column references) or only reference partition columns // but involve non-identity transformations. let mut data_preds: Vec = vec![]; - for e in data_split.into_iter() { + for e in data_split { let mut all_data_keys = true; let mut all_part_keys = true; let mut any_non_identity_part_keys = false; @@ -150,7 +153,7 @@ pub fn rewrite_predicate_for_partitioning( let source_to_pfield = { let mut map = HashMap::with_capacity(pfields.len()); - for pf in pfields.iter() { + for pf in pfields { if let Some(ref source_field) = pf.source_field { let prev_value = map.insert(source_field.name.as_str(), pf); if let Some(prev_value) = prev_value { @@ -162,7 +165,7 @@ pub fn rewrite_predicate_for_partitioning( }; let with_part_cols = predicate.transform(&|expr: ExprRef| { - use Operator::*; + use Operator::{Eq, Gt, GtEq, Lt, LtEq, NotEq}; match expr.as_ref() { // Binary Op for Eq // All transforms should work as is @@ -331,7 +334,7 @@ pub fn rewrite_predicate_for_partitioning( // Filter to predicate clauses that only involve partition columns. 
let split = split_conjuction(&with_part_cols); let mut part_preds: Vec = vec![]; - for e in split.into_iter() { + for e in split { let mut all_part_keys = true; e.apply(&mut |e: &ExprRef| { if let Expr::Column(col_name) = e.as_ref() diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 376548f7a7..a1ff42c138 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -118,7 +118,7 @@ fn run_glob_parallel( // Construct a static-lifetime BoxStreamIterator let iterator = BoxStreamIterator { boxstream, - runtime_handle: owned_runtime.clone(), + runtime_handle: owned_runtime, }; Ok(iterator) } @@ -148,7 +148,7 @@ impl GlobScanOperator { first_glob_path, Some(1), io_client.clone(), - io_runtime.clone(), + io_runtime, Some(io_stats.clone()), file_format, )?; @@ -177,7 +177,7 @@ impl GlobScanOperator { let (schema, _metadata) = daft_parquet::read::read_parquet_schema( first_filepath.as_str(), - io_client.clone(), + io_client, Some(io_stats), ParquetSchemaInferenceOptions { coerce_int96_timestamp_unit, @@ -313,9 +313,9 @@ impl ScanOperator for GlobScanOperator { let files = run_glob_parallel( self.glob_paths.clone(), - io_client.clone(), - io_runtime.clone(), - Some(io_stats.clone()), + io_client, + io_runtime, + Some(io_stats), file_format, )?; @@ -348,7 +348,7 @@ impl ScanOperator for GlobScanOperator { let chunk_spec = row_group.map(ChunkSpec::Parquet); Ok(ScanTask::new( vec![DataSource::File { - path: path.to_string(), + path, chunk_spec, size_bytes, iceberg_delete_files: None, diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 10cc0c6804..194250f45e 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize}; mod anonymous; pub use anonymous::AnonymousScanOperator; -mod glob; +pub mod glob; use common_daft_config::DaftExecutionConfig; pub mod scan_task_iters; @@ -98,18 +98,19 @@ impl From for pyo3::PyErr { } /// Specification of a subset of a file to be read. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum ChunkSpec { /// Selection of Parquet row groups. Parquet(Vec), } impl ChunkSpec { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; match self { Self::Parquet(chunks) => { - res.push(format!("Chunks = {:?}", chunks)); + res.push(format!("Chunks = {chunks:?}")); } } res @@ -147,6 +148,7 @@ pub enum DataSource { } impl DataSource { + #[must_use] pub fn get_path(&self) -> &str { match self { Self::File { path, .. } | Self::Database { path, .. } => path, @@ -155,6 +157,7 @@ impl DataSource { } } + #[must_use] pub fn get_parquet_metadata(&self) -> Option<&Arc> { match self { Self::File { @@ -164,6 +167,7 @@ impl DataSource { } } + #[must_use] pub fn get_chunk_spec(&self) -> Option<&ChunkSpec> { match self { Self::File { chunk_spec, .. } => chunk_spec.as_ref(), @@ -173,6 +177,7 @@ impl DataSource { } } + #[must_use] pub fn get_size_bytes(&self) -> Option { match self { Self::File { size_bytes, .. } | Self::Database { size_bytes, .. } => *size_bytes, @@ -181,6 +186,7 @@ impl DataSource { } } + #[must_use] pub fn get_metadata(&self) -> Option<&TableMetadata> { match self { Self::File { metadata, .. } | Self::Database { metadata, .. } => metadata.as_ref(), @@ -189,6 +195,7 @@ impl DataSource { } } + #[must_use] pub fn get_statistics(&self) -> Option<&TableStatistics> { match self { Self::File { statistics, .. } | Self::Database { statistics, .. 
} => { @@ -199,6 +206,7 @@ impl DataSource { } } + #[must_use] pub fn get_partition_spec(&self) -> Option<&PartitionSpec> { match self { Self::File { partition_spec, .. } => partition_spec.as_ref(), @@ -208,6 +216,7 @@ impl DataSource { } } + #[must_use] pub fn get_iceberg_delete_files(&self) -> Option<&Vec> { match self { Self::File { @@ -218,6 +227,7 @@ impl DataSource { } } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; match self { @@ -231,7 +241,7 @@ impl DataSource { statistics, parquet_metadata: _, } => { - res.push(format!("Path = {}", path)); + res.push(format!("Path = {path}")); if let Some(chunk_spec) = chunk_spec { res.push(format!( "Chunk spec = {{ {} }}", @@ -239,10 +249,10 @@ impl DataSource { )); } if let Some(size_bytes) = size_bytes { - res.push(format!("Size bytes = {}", size_bytes)); + res.push(format!("Size bytes = {size_bytes}")); } if let Some(iceberg_delete_files) = iceberg_delete_files { - res.push(format!("Iceberg delete files = {:?}", iceberg_delete_files)); + res.push(format!("Iceberg delete files = {iceberg_delete_files:?}")); } if let Some(metadata) = metadata { res.push(format!( @@ -257,7 +267,7 @@ impl DataSource { )); } if let Some(statistics) = statistics { - res.push(format!("Statistics = {}", statistics)); + res.push(format!("Statistics = {statistics}")); } } Self::Database { @@ -266,9 +276,9 @@ impl DataSource { metadata, statistics, } => { - res.push(format!("Path = {}", path)); + res.push(format!("Path = {path}")); if let Some(size_bytes) = size_bytes { - res.push(format!("Size bytes = {}", size_bytes)); + res.push(format!("Size bytes = {size_bytes}")); } if let Some(metadata) = metadata { res.push(format!( @@ -277,7 +287,7 @@ impl DataSource { )); } if let Some(statistics) = statistics { - res.push(format!("Statistics = {}", statistics)); + res.push(format!("Statistics = {statistics}")); } } #[cfg(feature = "python")] @@ -292,7 +302,7 @@ impl DataSource { } => { res.push(format!("Function = {module}.{func_name}")); if let Some(size_bytes) = size_bytes { - res.push(format!("Size bytes = {}", size_bytes)); + res.push(format!("Size bytes = {size_bytes}")); } if let Some(metadata) = metadata { res.push(format!( @@ -307,7 +317,7 @@ impl DataSource { )); } if let Some(statistics) = statistics { - res.push(format!("Statistics = {}", statistics)); + res.push(format!("Statistics = {statistics}")); } } } @@ -328,7 +338,7 @@ impl DisplayAs for DataSource { Self::PythonFactoryFunction { module, func_name, .. 
} => { - format!("{}:{}", module, func_name) + format!("{module}:{func_name}") } } } @@ -360,6 +370,7 @@ pub struct ScanTask { pub type ScanTaskRef = Arc; impl ScanTask { + #[must_use] pub fn new( sources: Vec, file_format_config: Arc, @@ -399,8 +410,8 @@ impl ScanTask { let metadata = length.map(|l| TableMetadata { length: l }); Self { sources, - file_format_config, schema, + file_format_config, storage_config, pushdowns, size_bytes_on_disk, @@ -453,6 +464,7 @@ impl ScanTask { )) } + #[must_use] pub fn materialized_schema(&self) -> SchemaRef { match &self.pushdowns.columns { None => self.schema.clone(), @@ -469,6 +481,7 @@ impl ScanTask { } /// Obtain an accurate, exact num_rows from the ScanTask, or `None` if this is not possible + #[must_use] pub fn num_rows(&self) -> Option { if self.pushdowns.filters.is_some() { // Cannot obtain an accurate num_rows if there are filters @@ -487,6 +500,7 @@ impl ScanTask { } /// Obtain an approximate num_rows from the ScanTask, or `None` if this is not possible + #[must_use] pub fn approx_num_rows(&self, config: Option<&DaftExecutionConfig>) -> Option { let approx_total_num_rows_before_pushdowns = self .metadata @@ -531,6 +545,7 @@ impl ScanTask { } /// Obtain the absolute maximum number of rows this ScanTask can give, or None if not possible to derive + #[must_use] pub fn upper_bound_rows(&self) -> Option { self.metadata.as_ref().map(|m| { if let Some(limit) = self.pushdowns.limit { @@ -541,10 +556,12 @@ impl ScanTask { }) } + #[must_use] pub fn size_bytes_on_disk(&self) -> Option { self.size_bytes_on_disk.map(|s| s as usize) } + #[must_use] pub fn estimate_in_memory_size_bytes( &self, config: Option<&DaftExecutionConfig>, @@ -570,6 +587,7 @@ impl ScanTask { }) } + #[must_use] pub fn partition_spec(&self) -> Option<&PartitionSpec> { match self.sources.first() { None => None, @@ -577,6 +595,7 @@ impl ScanTask { } } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; // TODO(Clark): Use above methods to display some of the more derived fields. 
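Most of the `multiline_display` churn above is a single mechanical rewrite: Rust 2021 format strings can capture identifiers directly, so `format!("Path = {}", path)` becomes `format!("Path = {path}")`. Both compile to the same code; the inlined form simply cannot drift out of sync with a trailing argument list:

```rust
fn main() {
    let path = "s3://bucket/data.parquet";
    let size_bytes = 1024_u64;
    let chunks = vec![0, 1, 2];

    // Positional arguments:
    let old = format!("Path = {}, Size bytes = {}", path, size_bytes);
    // Captured identifiers, as used throughout this diff; Debug formatting
    // (`{chunks:?}`) captures the same way.
    let new = format!("Path = {path}, Size bytes = {size_bytes}");

    assert_eq!(old, new);
    println!("Chunks = {chunks:?}");
}
```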
@@ -606,7 +625,7 @@ impl ScanTask { } res.extend(self.pushdowns.multiline_display()); if let Some(size_bytes) = self.size_bytes_on_disk { - res.push(format!("Size bytes on disk = {}", size_bytes)); + res.push(format!("Size bytes on disk = {size_bytes}")); } if let Some(metadata) = &self.metadata { res.push(format!( @@ -615,7 +634,7 @@ impl ScanTask { )); } if let Some(statistics) = &self.statistics { - res.push(format!("Statistics = {}", statistics)); + res.push(format!("Statistics = {statistics}")); } res } @@ -683,8 +702,7 @@ impl PartitionField { }) } (None, Some(tfm)) => Err(DaftError::ValueError(format!( - "transform set in PartitionField: {} but source_field not set", - tfm + "transform set in PartitionField: {tfm} but source_field not set" ))), _ => Ok(Self { field, @@ -726,16 +744,19 @@ pub enum PartitionTransform { } impl PartitionTransform { + #[must_use] pub fn supports_equals(&self) -> bool { true } + #[must_use] pub fn supports_not_equals(&self) -> bool { matches!(self, Self::Identity) } + #[must_use] pub fn supports_comparison(&self) -> bool { - use PartitionTransform::*; + use PartitionTransform::{Day, Hour, IcebergTruncate, Identity, Month, Year}; matches!( self, Identity | IcebergTruncate(_) | Year | Month | Day | Hour @@ -783,7 +804,7 @@ pub struct ScanOperatorRef(pub Arc); impl Hash for ScanOperatorRef { fn hash(&self, state: &mut H) { - Arc::as_ptr(&self.0).hash(state) + Arc::as_ptr(&self.0).hash(state); } } @@ -810,6 +831,7 @@ pub struct PhysicalScanInfo { } impl PhysicalScanInfo { + #[must_use] pub fn new( scan_op: ScanOperatorRef, source_schema: SchemaRef, @@ -824,6 +846,7 @@ impl PhysicalScanInfo { } } + #[must_use] pub fn with_pushdowns(&self, pushdowns: Pushdowns) -> Self { Self { scan_op: self.scan_op.clone(), @@ -853,6 +876,7 @@ impl Default for Pushdowns { } impl Pushdowns { + #[must_use] pub fn new( filters: Option, partition_filters: Option, @@ -867,6 +891,7 @@ impl Pushdowns { } } + #[must_use] pub fn is_empty(&self) -> bool { self.filters.is_none() && self.partition_filters.is_none() @@ -874,6 +899,7 @@ impl Pushdowns { && self.limit.is_none() } + #[must_use] pub fn with_limit(&self, limit: Option) -> Self { Self { filters: self.filters.clone(), @@ -883,6 +909,7 @@ impl Pushdowns { } } + #[must_use] pub fn with_filters(&self, filters: Option) -> Self { Self { filters, @@ -892,6 +919,7 @@ impl Pushdowns { } } + #[must_use] pub fn with_partition_filters(&self, partition_filters: Option) -> Self { Self { filters: self.filters.clone(), @@ -901,6 +929,7 @@ impl Pushdowns { } } + #[must_use] pub fn with_columns(&self, columns: Option>>) -> Self { Self { filters: self.filters.clone(), @@ -910,19 +939,20 @@ impl Pushdowns { } } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(columns) = &self.columns { res.push(format!("Projection pushdown = [{}]", columns.join(", "))); } if let Some(filters) = &self.filters { - res.push(format!("Filter pushdown = {}", filters)); + res.push(format!("Filter pushdown = {filters}")); } if let Some(pfilters) = &self.partition_filters { - res.push(format!("Partition Filter = {}", pfilters)); + res.push(format!("Partition Filter = {pfilters}")); } if let Some(limit) = self.limit { - res.push(format!("Limit pushdown = {}", limit)); + res.push(format!("Limit pushdown = {limit}")); } res } @@ -938,13 +968,13 @@ impl DisplayAs for Pushdowns { sub_items.push(format!("projection: [{}]", columns.join(", "))); } if let Some(filters) = &self.filters { - sub_items.push(format!("filter: {}", filters)); + 
sub_items.push(format!("filter: {filters}")); } if let Some(pfilters) = &self.partition_filters { - sub_items.push(format!("partition_filter: {}", pfilters)); + sub_items.push(format!("partition_filter: {pfilters}")); } if let Some(limit) = self.limit { - sub_items.push(format!("limit: {}", limit)); + sub_items.push(format!("limit: {limit}")); } s.push_str(&sub_items.join(", ")); s.push('}'); @@ -974,7 +1004,7 @@ mod test { fn make_scan_task(num_sources: usize) -> ScanTask { let sources = (0..num_sources) .map(|i| DataSource::File { - path: format!("test{}", i), + path: format!("test{i}"), chunk_spec: None, size_bytes: None, iceberg_delete_files: None, diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index fac37ccb48..23114c0f85 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -20,6 +20,7 @@ impl PythonTablesFactoryArgs { Self(args.into_iter().map(PyObjectSerializableWrapper).collect()) } + #[must_use] pub fn to_pytuple<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> { pyo3::types::PyTuple::new_bound(py, self.0.iter().map(|x| x.0.bind(py))) } @@ -321,9 +322,7 @@ pub mod pylib { // TODO(Clark): Filter out scan tasks with pushed down filters + table stats? let pspec = PartitionSpec { - keys: partition_values - .map(|p| p.table) - .unwrap_or_else(|| Table::empty(None).unwrap()), + keys: partition_values.map_or_else(|| Table::empty(None).unwrap(), |p| p.table), }; let statistics = stats .map(|s| TableStatistics::from_stats_table(&s.table)) @@ -461,7 +460,7 @@ pub mod pylib { ) -> PyResult { let p_field = PartitionField::new( field.field, - source_field.map(|f| f.into()), + source_field.map(std::convert::Into::into), transform.map(|e| e.0), )?; Ok(Self(Arc::new(p_field))) @@ -537,16 +536,19 @@ pub mod pylib { Ok(format!("{:#?}", self.0)) } #[getter] + #[must_use] pub fn limit(&self) -> Option { self.0.limit } #[getter] + #[must_use] pub fn filters(&self) -> Option { self.0.filters.as_ref().map(|e| PyExpr { expr: e.clone() }) } #[getter] + #[must_use] pub fn partition_filters(&self) -> Option { self.0 .partition_filters @@ -555,6 +557,7 @@ pub mod pylib { } #[getter] + #[must_use] pub fn columns(&self) -> Option> { self.0.columns.as_deref().cloned() } diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index b223ee5732..bd2054b6d4 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -25,6 +25,7 @@ type BoxScanTaskIter<'a> = Box> + 'a /// * `scan_tasks`: A Boxed Iterator of ScanTaskRefs to perform merging on /// * `min_size_bytes`: Minimum size in bytes of a ScanTask, after which no more merging will be performed /// * `max_size_bytes`: Maximum size in bytes of a ScanTask, capping the maximum size of a merged ScanTask +#[must_use] pub fn merge_by_sizes<'a>( scan_tasks: BoxScanTaskIter<'a>, pushdowns: &Pushdowns, @@ -35,7 +36,7 @@ pub fn merge_by_sizes<'a>( let mut scan_tasks = scan_tasks.peekable(); let first_scantask = scan_tasks .peek() - .and_then(|x| x.as_ref().map(|x| x.clone()).ok()); + .and_then(|x| x.as_ref().map(std::clone::Clone::clone).ok()); if let Some(first_scantask) = first_scantask { let estimated_bytes_for_reading_limit_rows = first_scantask .as_ref() @@ -175,6 +176,7 @@ impl<'a> Iterator for MergeByFileSize<'a> { } } +#[must_use] pub fn split_by_row_groups( scan_tasks: BoxScanTaskIter, max_tasks: usize, @@ -218,7 +220,7 @@ pub fn split_by_row_groups( .map_or(true, |s| s > max_size_bytes as u64) && source .get_iceberg_delete_files() - .map_or(true, 
|f| f.is_empty()) + .map_or(true, std::vec::Vec::is_empty) { let (io_runtime, io_client) = t.storage_config.get_io_client_and_runtime()?; @@ -226,7 +228,7 @@ pub fn split_by_row_groups( let path = source.get_path(); let io_stats = - IOStatsContext::new(format!("split_by_row_groups for {:#?}", path)); + IOStatsContext::new(format!("split_by_row_groups for {path:#?}")); let mut file = io_runtime.block_on_current_thread(read_parquet_metadata( path, @@ -243,7 +245,7 @@ pub fn split_by_row_groups( let row_groups = std::mem::take(&mut file.row_groups); let num_row_groups = row_groups.len(); - for (i, rg) in row_groups.into_iter() { + for (i, rg) in row_groups { curr_row_groups.push((i, rg)); let rg = &curr_row_groups.last().unwrap().1; curr_row_group_indices.push(i as i64); diff --git a/src/daft-scan/src/storage_config.rs b/src/daft-scan/src/storage_config.rs index d169e06510..9a672c8cce 100644 --- a/src/daft-scan/src/storage_config.rs +++ b/src/daft-scan/src/storage_config.rs @@ -50,6 +50,7 @@ impl StorageConfig { } } + #[must_use] pub fn var_name(&self) -> &'static str { match self { Self::Native(_) => "Native", @@ -58,6 +59,7 @@ impl StorageConfig { } } + #[must_use] pub fn multiline_display(&self) -> Vec { match self { Self::Native(source) => source.multiline_display(), @@ -76,6 +78,7 @@ pub struct NativeStorageConfig { } impl NativeStorageConfig { + #[must_use] pub fn new_internal(multithreaded_io: bool, io_config: Option) -> Self { Self { io_config, @@ -83,6 +86,7 @@ impl NativeStorageConfig { } } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(io_config) = &self.io_config { @@ -106,16 +110,19 @@ impl Default for NativeStorageConfig { #[pymethods] impl NativeStorageConfig { #[new] + #[must_use] pub fn new(multithreaded_io: bool, io_config: Option) -> Self { Self::new_internal(multithreaded_io, io_config.map(|c| c.config)) } #[getter] + #[must_use] pub fn io_config(&self) -> Option { - self.io_config.clone().map(|c| c.into()) + self.io_config.clone().map(std::convert::Into::into) } #[getter] + #[must_use] pub fn multithreaded_io(&self) -> bool { self.multithreaded_io } @@ -133,6 +140,7 @@ pub struct PythonStorageConfig { #[cfg(feature = "python")] impl PythonStorageConfig { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; if let Some(io_config) = &self.io_config { @@ -149,6 +157,7 @@ impl PythonStorageConfig { #[pymethods] impl PythonStorageConfig { #[new] + #[must_use] pub fn new(io_config: Option) -> Self { Self { io_config: io_config.map(|c| c.config), @@ -156,6 +165,7 @@ impl PythonStorageConfig { } #[getter] + #[must_use] pub fn io_config(&self) -> Option { self.io_config .as_ref() @@ -176,7 +186,7 @@ impl Eq for PythonStorageConfig {} #[cfg(feature = "python")] impl Hash for PythonStorageConfig { fn hash(&self, state: &mut H) { - self.io_config.hash(state) + self.io_config.hash(state); } } @@ -207,7 +217,7 @@ impl PyStorageConfig { /// Get the underlying storage config. 
#[getter] fn get_config(&self, py: Python) -> PyObject { - use StorageConfig::*; + use StorageConfig::{Native, Python}; match self.0.as_ref() { Native(config) => config.as_ref().clone().into_py(py), diff --git a/src/daft-scheduler/src/adaptive.rs b/src/daft-scheduler/src/adaptive.rs index e701bd0e73..2011a9c2ba 100644 --- a/src/daft-scheduler/src/adaptive.rs +++ b/src/daft-scheduler/src/adaptive.rs @@ -16,6 +16,7 @@ pub struct AdaptivePhysicalPlanScheduler { } impl AdaptivePhysicalPlanScheduler { + #[must_use] pub fn new(logical_plan: Arc, cfg: Arc) -> Self { Self { planner: AdaptivePlanner::new(logical_plan, cfg), @@ -71,8 +72,8 @@ impl AdaptivePhysicalPlanScheduler { ); self.planner.update(MaterializedResults { - in_memory_info, source_id, + in_memory_info, })?; Ok(()) }) diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 709dd8ff4d..eeedc4471a 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -5,9 +5,17 @@ use common_error::DaftResult; use common_file_formats::FileFormat; use common_py_serde::impl_bincode_py_state_serialization; use daft_dsl::ExprRef; +#[cfg(feature = "python")] +use daft_plan::physical_ops::{DeltaLakeWrite, IcebergWrite, LanceWrite}; use daft_plan::{ - logical_to_physical, physical_ops::*, InMemoryInfo, PhysicalPlan, PhysicalPlanRef, - QueryStageOutput, + logical_to_physical, + physical_ops::{ + ActorPoolProject, Aggregate, BroadcastJoin, Coalesce, Concat, EmptyScan, Explode, + FanoutByHash, FanoutRandom, Filter, Flatten, HashJoin, InMemoryScan, Limit, + MonotonicallyIncreasingId, Pivot, Project, ReduceMerge, Sample, Sort, SortMergeJoin, Split, + TabularScan, TabularWriteCsv, TabularWriteJson, TabularWriteParquet, Unpivot, + }, + InMemoryInfo, PhysicalPlan, PhysicalPlanRef, QueryStageOutput, }; #[cfg(feature = "python")] use daft_plan::{DeltaLakeCatalogInfo, IcebergCatalogInfo, LanceCatalogInfo}; diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index 65cf8f808e..00ef1083ca 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -7,6 +7,8 @@ use serde::{Deserialize, Serialize}; use crate::{field::Field, image_mode::ImageMode, time_unit::TimeUnit}; +pub type DaftDataType = DataType; + #[derive(Clone, Debug, Display, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum DataType { // ArrowTypes: @@ -107,8 +109,11 @@ pub enum DataType { Struct(Vec), /// A nested [`DataType`] that is represented as List>. - #[display("Map[{_0}]")] - Map(Box), + #[display("Map[{key}: {value}]")] + Map { + key: Box, + value: Box, + }, /// Extension type. #[display("{_1}")] @@ -233,14 +238,29 @@ impl DataType { Self::List(field) => Ok(ArrowType::LargeList(Box::new( arrow2::datatypes::Field::new("item", field.to_arrow()?, true), ))), - Self::Map(field) => Ok(ArrowType::Map( - Box::new(arrow2::datatypes::Field::new( - "item", - field.to_arrow()?, - true, - )), - false, - )), + Self::Map { key, value } => { + let struct_type = ArrowType::Struct(vec![ + // We never allow null keys in maps for several reasons: + // 1. Null typically represents the absence of a value, which doesn't make sense for a key. + // 2. Null comparisons can be problematic (similar to how f64::NAN != f64::NAN). + // 3. It maintains consistency with common map implementations in arrow (no null keys). + // 4. 
It simplifies map operations + // + // This decision aligns with the thoughts of team members like Jay and Sammy, who argue that: + // - Nulls in keys could lead to unintuitive behavior + // - If users need to count or group by null values, they can use other constructs like + // group_by operations on non-map types, which offer more explicit control. + // + // By disallowing null keys, we encourage more robust data modeling practices and + // provide a clearer semantic meaning for map types in our system. + arrow2::datatypes::Field::new("key", key.to_arrow()?, true), + arrow2::datatypes::Field::new("value", value.to_arrow()?, true), + ]); + + let struct_field = arrow2::datatypes::Field::new("entries", struct_type, true); + + Ok(ArrowType::map(struct_field, false)) + } Self::Struct(fields) => Ok({ let fields = fields .iter() @@ -288,7 +308,10 @@ impl DataType { FixedSizeList(child_dtype, size) => { FixedSizeList(Box::new(child_dtype.to_physical()), *size) } - Map(child_dtype) => List(Box::new(child_dtype.to_physical())), + Map { key, value } => List(Box::new(Struct(vec![ + Field::new("key", key.to_physical()), + Field::new("value", value.to_physical()), + ]))), Embedding(dtype, size) => FixedSizeList(Box::new(dtype.to_physical()), *size), Image(mode) => Struct(vec![ Field::new( @@ -328,20 +351,6 @@ impl DataType { } } - #[inline] - pub fn nested_dtype(&self) -> Option<&Self> { - match self { - Self::Map(dtype) - | Self::List(dtype) - | Self::FixedSizeList(dtype, _) - | Self::FixedShapeTensor(dtype, _) - | Self::SparseTensor(dtype) - | Self::FixedShapeSparseTensor(dtype, _) - | Self::Tensor(dtype) => Some(dtype), - _ => None, - } - } - #[inline] pub fn is_arrow(&self) -> bool { self.to_arrow().is_ok() @@ -350,21 +359,33 @@ impl DataType { #[inline] pub fn is_numeric(&self) -> bool { match self { - Self::Int8 - | Self::Int16 - | Self::Int32 - | Self::Int64 - | Self::Int128 - | Self::UInt8 - | Self::UInt16 - | Self::UInt32 - | Self::UInt64 - // DataType::Float16 - | Self::Float32 - | Self::Float64 => true, - Self::Extension(_, inner, _) => inner.is_numeric(), - _ => false - } + Self::Int8 + | Self::Int16 + | Self::Int32 + | Self::Int64 + | Self::Int128 + | Self::UInt8 + | Self::UInt16 + | Self::UInt32 + | Self::UInt64 + // DataType::Float16 + | Self::Float32 + | Self::Float64 => true, + Self::Extension(_, inner, _) => inner.is_numeric(), + _ => false + } + } + + #[inline] + pub fn assert_is_numeric(&self) -> DaftResult<()> { + if self.is_numeric() { + Ok(()) + } else { + Err(DaftError::TypeError(format!( + "Numeric mean is not implemented for type {}", + self, + ))) + } } #[inline] @@ -453,7 +474,7 @@ impl DataType { #[inline] pub fn is_map(&self) -> bool { - matches!(self, Self::Map(..)) + matches!(self, Self::Map { .. }) } #[inline] @@ -576,7 +597,7 @@ impl DataType { | Self::FixedShapeTensor(..) | Self::SparseTensor(..) | Self::FixedShapeSparseTensor(..) - | Self::Map(..) + | Self::Map { .. } ) } @@ -590,7 +611,7 @@ impl DataType { let p: Self = self.to_physical(); matches!( p, - Self::List(..) | Self::FixedSizeList(..) | Self::Struct(..) | Self::Map(..) + Self::List(..) | Self::FixedSizeList(..) | Self::Struct(..) | Self::Map { .. 
} ) } @@ -605,9 +626,13 @@ impl DataType { } } +#[expect( + clippy::fallible_impl_from, + reason = "https://github.com/Eventual-Inc/Daft/issues/3015" +)] impl From<&ArrowType> for DataType { fn from(item: &ArrowType) -> Self { - match item { + let result = match item { ArrowType::Null => Self::Null, ArrowType::Boolean => Self::Boolean, ArrowType::Int8 => Self::Int8, @@ -638,7 +663,29 @@ impl From<&ArrowType> for DataType { ArrowType::FixedSizeList(field, size) => { Self::FixedSizeList(Box::new(field.as_ref().data_type().into()), *size) } - ArrowType::Map(field, ..) => Self::Map(Box::new(field.as_ref().data_type().into())), + ArrowType::Map(field, ..) => { + // todo: TryFrom in future? want in second pass maybe + + // field should be a struct + let ArrowType::Struct(fields) = &field.data_type else { + panic!("Map should have a struct as its key") + }; + + let [key, value] = fields.as_slice() else { + panic!("Map should have two fields") + }; + + let key = &key.data_type; + let value = &value.data_type; + + let key = Self::from(key); + let value = Self::from(value); + + let key = Box::new(key); + let value = Box::new(value); + + Self::Map { key, value } + } ArrowType::Struct(fields) => { let fields: Vec = fields.iter().map(|fld| fld.into()).collect(); Self::Struct(fields) @@ -659,7 +706,9 @@ impl From<&ArrowType> for DataType { } _ => panic!("DataType :{item:?} is not supported"), - } + }; + + result } } diff --git a/src/daft-schema/src/field.rs b/src/daft-schema/src/field.rs index 774545fee4..f4cd6ecb16 100644 --- a/src/daft-schema/src/field.rs +++ b/src/daft-schema/src/field.rs @@ -18,6 +18,7 @@ pub struct Field { } pub type FieldRef = Arc; +pub type DaftField = Field; #[derive(Clone, Display, Debug, PartialEq, Eq, Deserialize, Serialize, Hash)] #[display("{id}")] @@ -87,6 +88,14 @@ impl Field { ) } + pub fn to_physical(&self) -> Self { + Self { + name: self.name.clone(), + dtype: self.dtype.to_physical(), + metadata: self.metadata.clone(), + } + } + pub fn rename>(&self, name: S) -> Self { Self { name: name.into(), diff --git a/src/daft-schema/src/image_format.rs b/src/daft-schema/src/image_format.rs index 93ec40963e..0aeb8432de 100644 --- a/src/daft-schema/src/image_format.rs +++ b/src/daft-schema/src/image_format.rs @@ -39,7 +39,7 @@ impl ImageFormat { impl ImageFormat { pub fn iterator() -> std::slice::Iter<'static, Self> { - use ImageFormat::*; + use ImageFormat::{BMP, GIF, JPEG, PNG, TIFF}; static FORMATS: [ImageFormat; 5] = [PNG, JPEG, TIFF, GIF, BMP]; FORMATS.iter() @@ -50,7 +50,7 @@ impl FromStr for ImageFormat { type Err = DaftError; fn from_str(format: &str) -> DaftResult { - use ImageFormat::*; + use ImageFormat::{BMP, GIF, JPEG, PNG, TIFF}; match format { "PNG" => Ok(PNG), diff --git a/src/daft-schema/src/image_mode.rs b/src/daft-schema/src/image_mode.rs index 9b41875ff0..e75e90bf28 100644 --- a/src/daft-schema/src/image_mode.rs +++ b/src/daft-schema/src/image_mode.rs @@ -1,3 +1,4 @@ +#![expect(non_local_definitions, reason = "we want to remove this...")] use std::str::FromStr; use common_error::{DaftError, DaftResult}; @@ -65,7 +66,7 @@ impl ImageMode { impl ImageMode { pub fn from_pil_mode_str(mode: &str) -> DaftResult { - use ImageMode::*; + use ImageMode::{L, LA, RGB, RGBA}; match mode { "L" => Ok(L), @@ -85,7 +86,7 @@ impl ImageMode { } } pub fn try_from_num_channels(num_channels: u16, dtype: &DataType) -> DaftResult { - use ImageMode::*; + use ImageMode::{L, L16, LA, LA16, RGB, RGB16, RGB32F, RGBA, RGBA16, RGBA32F}; match (num_channels, dtype) { (1, DataType::UInt8) => 
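The `Map` rework above replaces a single opaque child (`Map(Box<DataType>)`) with named `key` and `value` children, which is why every pattern in this diff moves from `Self::Map(..)` to `Self::Map { .. }`. A simplified stand-in for the enum (not Daft's real `DataType`) showing the new variant shape and the `List<Struct<key, value>>` physical layout it lowers to:

```rust
#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Int64,
    Utf8,
    List(Box<DataType>),
    Struct(Vec<(String, DataType)>),
    Map { key: Box<DataType>, value: Box<DataType> },
}

impl DataType {
    // Mirrors the `to_physical` change above: a map is stored as a list of
    // { key, value } structs. (This sketch does not recurse into List/Struct
    // children, which the real implementation does.)
    fn to_physical(&self) -> DataType {
        match self {
            DataType::Map { key, value } => DataType::List(Box::new(DataType::Struct(vec![
                ("key".to_string(), key.to_physical()),
                ("value".to_string(), value.to_physical()),
            ]))),
            other => other.clone(),
        }
    }
}

fn main() {
    let m = DataType::Map {
        key: Box::new(DataType::Utf8),
        value: Box::new(DataType::Int64),
    };
    // Struct variants are matched with `{ .. }`, hence the new `is_map` shape.
    assert!(matches!(&m, DataType::Map { .. }));
    println!("{:?}", m.to_physical());
}
```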
Ok(L), @@ -99,13 +100,13 @@ impl ImageMode { (4, DataType::UInt16) => Ok(RGBA16), (4, DataType::Float32) => Ok(RGBA32F), (_, _) => Err(DaftError::ValueError(format!( - "Images with more than {} channels and dtype {} are not supported", - num_channels, dtype, + "Images with more than {num_channels} channels and dtype {dtype} are not supported", ))), } } + #[must_use] pub fn num_channels(&self) -> u16 { - use ImageMode::*; + use ImageMode::{L, L16, LA, LA16, RGB, RGB16, RGB32F, RGBA, RGBA16, RGBA32F}; match self { L | L16 => 1, @@ -115,12 +116,13 @@ impl ImageMode { } } pub fn iterator() -> std::slice::Iter<'static, Self> { - use ImageMode::*; + use ImageMode::{L, L16, LA, LA16, RGB, RGB16, RGB32F, RGBA, RGBA16, RGBA32F}; static MODES: [ImageMode; 10] = [L, LA, RGB, RGBA, L16, LA16, RGB16, RGBA16, RGB32F, RGBA32F]; MODES.iter() } + #[must_use] pub fn get_dtype(&self) -> DataType { self.into() } @@ -130,7 +132,7 @@ impl FromStr for ImageMode { type Err = DaftError; fn from_str(mode: &str) -> DaftResult { - use ImageMode::*; + use ImageMode::{L, L16, LA, LA16, RGB, RGB16, RGB32F, RGBA, RGBA16, RGBA32F}; match mode { "L" => Ok(L), diff --git a/src/daft-schema/src/python/datatype.rs b/src/daft-schema/src/python/datatype.rs index ceff5e18f3..2c1b0eba11 100644 --- a/src/daft-schema/src/python/datatype.rs +++ b/src/daft-schema/src/python/datatype.rs @@ -53,6 +53,7 @@ impl PyTimeUnit { _ => Err(pyo3::exceptions::PyNotImplementedError::new_err(())), } } + #[must_use] pub fn __hash__(&self) -> u64 { use std::{ collections::hash_map::DefaultHasher, @@ -145,8 +146,7 @@ impl PyDataType { pub fn fixed_size_binary(size: i64) -> PyResult { if size <= 0 { return Err(PyValueError::new_err(format!( - "The size for fixed-size binary types must be a positive integer, but got: {}", - size + "The size for fixed-size binary types must be a positive integer, but got: {size}" ))); } Ok(DataType::FixedSizeBinary(usize::try_from(size)?).into()) @@ -200,8 +200,7 @@ impl PyDataType { pub fn fixed_size_list(data_type: Self, size: i64) -> PyResult { if size <= 0 { return Err(PyValueError::new_err(format!( - "The size for fixed-size list types must be a positive integer, but got: {}", - size + "The size for fixed-size list types must be a positive integer, but got: {size}" ))); } Ok(DataType::FixedSizeList(Box::new(data_type.dtype), usize::try_from(size)?).into()) @@ -209,14 +208,15 @@ impl PyDataType { #[staticmethod] pub fn map(key_type: Self, value_type: Self) -> PyResult { - Ok(DataType::Map(Box::new(DataType::Struct(vec![ - Field::new("key", key_type.dtype), - Field::new("value", value_type.dtype), - ]))) + Ok(DataType::Map { + key: Box::new(key_type.dtype), + value: Box::new(value_type.dtype), + } .into()) } #[staticmethod] + #[must_use] pub fn r#struct(fields: IndexMap) -> Self { DataType::Struct( fields @@ -236,7 +236,7 @@ impl PyDataType { Ok(DataType::Extension( name.to_string(), Box::new(storage_data_type.dtype), - metadata.map(|s| s.to_string()), + metadata.map(std::string::ToString::to_string), ) .into()) } @@ -245,8 +245,7 @@ impl PyDataType { pub fn embedding(data_type: Self, size: i64) -> PyResult { if size <= 0 { return Err(PyValueError::new_err(format!( - "The size for embedding types must be a positive integer, but got: {}", - size + "The size for embedding types must be a positive integer, but got: {size}" ))); } if !data_type.dtype.is_numeric() { @@ -267,13 +266,13 @@ impl PyDataType { ) -> PyResult { match (height, width) { (Some(height), Some(width)) => { - let image_mode = 
mode.ok_or(PyValueError::new_err( + let image_mode = mode.ok_or_else(|| PyValueError::new_err( "Image mode must be provided if specifying an image size.", ))?; Ok(DataType::FixedShapeImage(image_mode, height, width).into()) } (None, None) => Ok(DataType::Image(mode).into()), - (_, _) => Err(PyValueError::new_err(format!("Height and width for image type must both be specified or both not specified, but got: height={:?}, width={:?}", height, width))), + (_, _) => Err(PyValueError::new_err(format!("Height and width for image type must both be specified or both not specified, but got: height={height:?}, width={width:?}"))), } } @@ -408,6 +407,7 @@ impl PyDataType { Ok(DataType::from_json(serialized)?.into()) } + #[must_use] pub fn __hash__(&self) -> u64 { use std::{ collections::hash_map::DefaultHasher, diff --git a/src/daft-schema/src/python/schema.rs b/src/daft-schema/src/python/schema.rs index 3a13583ba8..bacc8cc8cf 100644 --- a/src/daft-schema/src/python/schema.rs +++ b/src/daft-schema/src/python/schema.rs @@ -42,6 +42,7 @@ impl PySchema { .call1((pyarrow_fields,)) } + #[must_use] pub fn names(&self) -> Vec { self.schema.names() } diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs index 04c0d88c71..d220897228 100644 --- a/src/daft-schema/src/schema.rs +++ b/src/daft-schema/src/schema.rs @@ -29,15 +29,19 @@ pub struct Schema { impl Schema { pub fn new(fields: Vec) -> DaftResult { - let mut map: IndexMap = indexmap::IndexMap::new(); - - for f in fields.into_iter() { - let old = map.insert(f.name.clone(), f); - if let Some(item) = old { - return Err(DaftError::ValueError(format!( - "Attempting to make a Schema with duplicate field names: {}", - item.name - ))); + let mut map = IndexMap::new(); + + for f in fields { + match map.entry(f.name.clone()) { + indexmap::map::Entry::Vacant(entry) => { + entry.insert(f); + } + indexmap::map::Entry::Occupied(entry) => { + return Err(DaftError::ValueError(format!( + "Attempting to make a Schema with duplicate field names: {}", + entry.key() + ))); + } } } @@ -47,7 +51,7 @@ impl Schema { pub fn exclude>(&self, names: &[S]) -> DaftResult { let mut fields = IndexMap::new(); let names = names.iter().map(|s| s.as_ref()).collect::>(); - for (name, field) in self.fields.iter() { + for (name, field) in &self.fields { if !names.contains(&name.as_str()) { fields.insert(name.clone(), field.clone()); } @@ -257,7 +261,7 @@ impl Schema { impl Hash for Schema { fn hash(&self, state: &mut H) { - state.write_u64(hash_index_map(&self.fields)) + state.write_u64(hash_index_map(&self.fields)); } } diff --git a/src/daft-schema/src/time_unit.rs b/src/daft-schema/src/time_unit.rs index d4b17b0e7c..8f34409271 100644 --- a/src/daft-schema/src/time_unit.rs +++ b/src/daft-schema/src/time_unit.rs @@ -1,4 +1,7 @@ +use std::str::FromStr; + use arrow2::datatypes::TimeUnit as ArrowTimeUnit; +use common_error::DaftError; use derive_more::Display; use serde::{Deserialize, Serialize}; @@ -14,6 +17,7 @@ pub enum TimeUnit { impl TimeUnit { #![allow(clippy::wrong_self_convention)] + #[must_use] pub fn to_arrow(&self) -> ArrowTimeUnit { match self { Self::Nanoseconds => ArrowTimeUnit::Nanosecond, @@ -23,6 +27,7 @@ impl TimeUnit { } } + #[must_use] pub fn to_scale_factor(&self) -> i64 { match self { Self::Seconds => 1, @@ -33,6 +38,19 @@ impl TimeUnit { } } +impl FromStr for TimeUnit { + type Err = DaftError; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "ns" | "nanoseconds" => Ok(Self::Nanoseconds), + "us" | "microseconds" => 
Ok(Self::Microseconds), + "ms" | "milliseconds" => Ok(Self::Milliseconds), + "s" | "seconds" => Ok(Self::Seconds), + _ => Err(DaftError::ValueError("Invalid time unit".to_string())), + } + } +} + impl From<&ArrowTimeUnit> for TimeUnit { fn from(tu: &ArrowTimeUnit) -> Self { match tu { @@ -44,6 +62,7 @@ impl From<&ArrowTimeUnit> for TimeUnit { } } +#[must_use] pub fn infer_timeunit_from_format_string(format: &str) -> TimeUnit { if format.contains("%9f") || format.contains("%.9f") { TimeUnit::Nanoseconds diff --git a/src/daft-sketch/src/arrow2_serde.rs b/src/daft-sketch/src/arrow2_serde.rs index 4213d0180c..32b10cc30c 100644 --- a/src/daft-sketch/src/arrow2_serde.rs +++ b/src/daft-sketch/src/arrow2_serde.rs @@ -16,10 +16,10 @@ enum Error { impl From for DaftError { fn from(value: Error) -> Self { - use Error::*; + use Error::DeserializationError; match value { DeserializationError { source } => { - Self::ComputeError(format!("Deserialization error: {}", source)) + Self::ComputeError(format!("Deserialization error: {source}")) } } } @@ -37,6 +37,7 @@ lazy_static! { } /// Converts a Vec> into an arrow2 Array +#[must_use] pub fn into_arrow2(sketches: Vec>) -> Box { if sketches.is_empty() { return arrow2::array::StructArray::new_empty(ARROW2_DDSKETCH_DTYPE.clone()).to_boxed(); @@ -64,7 +65,7 @@ pub fn from_arrow2( item_vec .map(|item_vec| item_vec.into_iter().map(|item| item.0).collect()) .with_context(|_| DeserializationSnafu {}) - .map_err(|e| e.into()) + .map_err(std::convert::Into::into) } #[cfg(test)] @@ -79,7 +80,7 @@ mod tests { let mut sketch = DDSketch::new(Config::default()); for i in 0..10 { - sketch.add(i as f64); + sketch.add(f64::from(i)); } let expected_min = sketch.min(); diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 2b80dda42c..81d7d36ff0 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -1,6 +1,7 @@ [dependencies] common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} +common-io-config = {path = "../common/io-config", default-features = false} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/catalog.rs b/src/daft-sql/src/catalog.rs index 3495d32703..f4a08e6230 100644 --- a/src/daft-sql/src/catalog.rs +++ b/src/daft-sql/src/catalog.rs @@ -10,6 +10,7 @@ pub struct SQLCatalog { impl SQLCatalog { /// Create an empty catalog + #[must_use] pub fn new() -> Self { Self { tables: HashMap::new(), @@ -22,13 +23,14 @@ impl SQLCatalog { } /// Get a table from the catalog + #[must_use] pub fn get_table(&self, name: &str) -> Option { self.tables.get(name).cloned() } /// Copy from another catalog, using tables from other in case of conflict pub fn copy_from(&mut self, other: &Self) { - for (name, plan) in other.tables.iter() { + for (name, plan) in &other.tables { self.tables.insert(name.clone(), plan.clone()); } } diff --git a/src/daft-sql/src/error.rs b/src/daft-sql/src/error.rs index 31f8a400ed..7cfa8428aa 100644 --- a/src/daft-sql/src/error.rs +++ b/src/daft-sql/src/error.rs @@ -12,6 +12,8 @@ pub enum PlannerError { ParseError { message: String }, #[snafu(display("Invalid operation: {message}"))] InvalidOperation { message: String }, + #[snafu(display("Invalid argument ({message}) for function '{function}'"))] + InvalidFunctionArgument { message: String, function: String }, #[snafu(display("Table not found: {message}"))] TableNotFound { message: String }, #[snafu(display("Column {column_name} not found in 
{relation}"))]
@@ -57,6 +59,7 @@ impl PlannerError {
         }
     }
 
+    #[must_use]
     pub fn unsupported_sql(sql: String) -> Self {
         Self::UnsupportedSQL { message: sql }
     }
@@ -66,6 +69,13 @@ impl PlannerError {
             message: message.into(),
         }
     }
+
+    pub fn invalid_argument<S: Into<String>, F: Into<String>>(arg: S, function: F) -> Self {
+        Self::InvalidFunctionArgument {
+            message: arg.into(),
+            function: function.into(),
+        }
+    }
 }
 
 #[macro_export]
diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs
index 6172d2d382..db35adf141 100644
--- a/src/daft-sql/src/functions.rs
+++ b/src/daft-sql/src/functions.rs
@@ -1,6 +1,7 @@
 use std::{collections::HashMap, sync::Arc};
 
 use daft_dsl::ExprRef;
+use hashing::SQLModuleHashing;
 use once_cell::sync::Lazy;
 use sqlparser::ast::{
     Function, FunctionArg, FunctionArgExpr, FunctionArgOperator, FunctionArguments,
@@ -8,7 +9,11 @@ use sqlparser::ast::{
 
 use crate::{
     error::{PlannerError, SQLPlannerResult},
-    modules::*,
+    modules::{
+        hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat, SQLModuleImage,
+        SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric, SQLModulePartitioning,
+        SQLModulePython, SQLModuleSketch, SQLModuleStructs, SQLModuleTemporal, SQLModuleUtf8,
+    },
     planner::SQLPlanner,
     unsupported_sql_err,
 };
@@ -18,6 +23,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
     let mut functions = SQLFunctions::new();
     functions.register::<SQLModuleAggs>();
     functions.register::<SQLModuleFloat>();
+    functions.register::<SQLModuleHashing>();
     functions.register::<SQLModuleImage>();
     functions.register::<SQLModuleJson>();
     functions.register::<SQLModuleList>();
@@ -29,6 +35,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
     functions.register::<SQLModuleStructs>();
     functions.register::<SQLModuleTemporal>();
     functions.register::<SQLModuleUtf8>();
+    functions.register::<SQLModuleConfig>();
     functions
 });
@@ -82,6 +89,16 @@ pub trait SQLFunction: Send + Sync {
     }
 
     fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef>;
+
+    /// Produce the docstrings for this SQL function, parametrized by an alias which is the function name to invoke this in SQL
+    fn docstrings(&self, alias: &str) -> String {
+        format!("{alias}: No docstring available")
+    }
+
+    /// Produce the argument names for this SQL function
+    fn arg_names(&self) -> &'static [&'static str] {
+        &["todo"]
+    }
 }
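With `docstrings` and `arg_names` on the trait, every registered function can carry its documentation alongside its planning logic, and `add_fn` records both at registration time. A hedged sketch of a custom implementation (the `greatest` function and its semantics are hypothetical; the expression fold is elided):

```rust
// Hypothetical SQLFunction implementation; `SQLGreatest` is illustrative,
// but the trait methods are the ones defined above.
struct SQLGreatest;

impl SQLFunction for SQLGreatest {
    fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
        let exprs = inputs
            .iter()
            .map(|arg| planner.plan_function_arg(arg))
            .collect::<SQLPlannerResult<Vec<_>>>()?;
        let first = exprs.first().cloned().ok_or_else(|| {
            PlannerError::invalid_argument("expects at least one argument", "greatest")
        })?;
        // Fold the remaining expressions into a max-style expression here (elided).
        Ok(first)
    }

    fn docstrings(&self, alias: &str) -> String {
        format!("{alias}: Returns the largest of its arguments.")
    }

    fn arg_names(&self) -> &'static [&'static str] {
        &["exprs"]
    }
}
```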
 
 /// TODOs
 ///   - Add more functions..
 pub struct SQLFunctions {
     pub(crate) map: HashMap<String, Arc<dyn SQLFunction>>,
+    pub(crate) docsmap: HashMap<String, (String, &'static [&'static str])>,
 }
 
 pub(crate) struct SQLFunctionArguments {
@@ -97,19 +115,86 @@ pub(crate) struct SQLFunctionArguments {
 }
 
 impl SQLFunctionArguments {
-    pub fn get_unnamed(&self, idx: usize) -> Option<&ExprRef> {
+    pub fn get_positional(&self, idx: usize) -> Option<&ExprRef> {
         self.positional.get(&idx)
     }
     pub fn get_named(&self, name: &str) -> Option<&ExprRef> {
         self.named.get(name)
     }
+
+    pub fn try_get_named<T: SQLLiteral>(&self, name: &str) -> Result<Option<T>, PlannerError> {
+        self.named
+            .get(name)
+            .map(|expr| T::from_expr(expr))
+            .transpose()
+    }
+    pub fn try_get_positional<T: SQLLiteral>(&self, idx: usize) -> Result<Option<T>, PlannerError> {
+        self.positional
+            .get(&idx)
+            .map(|expr| T::from_expr(expr))
+            .transpose()
+    }
+}
+
+pub trait SQLLiteral {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized;
+}
+
+impl SQLLiteral for String {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        let e = expr
+            .as_literal()
+            .and_then(|lit| lit.as_str())
+            .ok_or_else(|| PlannerError::invalid_operation("Expected a string literal"))?;
+        Ok(e.to_string())
+    }
+}
+
+impl SQLLiteral for i64 {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        expr.as_literal()
+            .and_then(daft_dsl::LiteralValue::as_i64)
+            .ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal"))
+    }
+}
+
+impl SQLLiteral for usize {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        expr.as_literal()
+            .and_then(|lit| lit.as_i64().map(|v| v as Self))
+            .ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal"))
+    }
+}
+
+impl SQLLiteral for bool {
+    fn from_expr(expr: &ExprRef) -> Result<Self, PlannerError>
+    where
+        Self: Sized,
+    {
+        expr.as_literal()
+            .and_then(daft_dsl::LiteralValue::as_bool)
+            .ok_or_else(|| PlannerError::invalid_operation("Expected a boolean literal"))
+    }
 }
 
 impl SQLFunctions {
     /// Create a new [SQLFunctions] instance.
+    #[must_use]
     pub fn new() -> Self {
         Self {
             map: HashMap::new(),
+            docsmap: HashMap::new(),
         }
     }
 
@@ -120,10 +205,13 @@ impl SQLFunctions {
 
     /// Add a [FunctionExpr] to the [SQLFunctions] instance.
     pub fn add_fn<F: SQLFunction + 'static>(&mut self, name: &str, func: F) {
+        self.docsmap
+            .insert(name.to_string(), (func.docstrings(name), func.arg_names()));
         self.map.insert(name.to_string(), Arc::new(func));
     }
 
     /// Get a function by name from the [SQLFunctions] instance.
+    #[must_use]
     pub fn get(&self, name: &str) -> Option<&Arc<dyn SQLFunction>> {
         self.map.get(name)
     }
@@ -194,6 +282,15 @@ impl SQLPlanner {
     where
         T: TryFrom<SQLFunctionArguments, Error = PlannerError>,
     {
+        self.parse_function_args(args, expected_named, expected_positional)?
+            .try_into()
+    }
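The `SQLLiteral` impls above give typed access to literal arguments: combined with `try_get_named` and `try_get_positional`, an absent argument comes back as `Ok(None)` while a present-but-mistyped one surfaces a `PlannerError`. A usage sketch with hypothetical argument names:

```rust
// Illustrative only: "pattern", "ignore_case" and the positional slot are
// made-up argument names, not ones this diff defines.
fn extract_args(args: &SQLFunctionArguments) -> Result<(String, usize, bool), PlannerError> {
    // Required named argument: absent yields Ok(None), so turn that into an error.
    let pattern: String = args
        .try_get_named("pattern")?
        .ok_or_else(|| PlannerError::invalid_operation("missing required 'pattern'"))?;
    // Optional arguments compose naturally with unwrap_or defaults.
    let max_splits: usize = args.try_get_positional(1)?.unwrap_or(0);
    let ignore_case: bool = args.try_get_named("ignore_case")?.unwrap_or(false);
    Ok((pattern, max_splits, ignore_case))
}
```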
impl SQLFunctions { /// Create a new [SQLFunctions] instance. + #[must_use] pub fn new() -> Self { Self { map: HashMap::new(), + docsmap: HashMap::new(), } } @@ -120,10 +205,13 @@ impl SQLFunctions { /// Add a [FunctionExpr] to the [SQLFunctions] instance. pub fn add_fn<F: SQLFunction + 'static>(&mut self, name: &str, func: F) { + self.docsmap .insert(name.to_string(), (func.docstrings(name), func.arg_names())); self.map.insert(name.to_string(), Arc::new(func)); } /// Get a function by name from the [SQLFunctions] instance. + #[must_use] pub fn get(&self, name: &str) -> Option<&Arc<dyn SQLFunction>> { self.map.get(name) } @@ -194,6 +282,15 @@ impl SQLPlanner { where T: TryFrom<SQLFunctionArguments, Error = PlannerError>, { + self.parse_function_args(args, expected_named, expected_positional)? + .try_into() + } + pub(crate) fn parse_function_args( + &self, + args: &[FunctionArg], + expected_named: &'static [&'static str], + expected_positional: usize, + ) -> SQLPlannerResult<SQLFunctionArguments> { let mut positional_args = HashMap::new(); let mut named_args = HashMap::new(); for (idx, arg) in args.iter().enumerate() { @@ -214,15 +311,14 @@ impl SQLPlanner { } positional_args.insert(idx, self.try_unwrap_function_arg_expr(arg)?); } - _ => unsupported_sql_err!("unsupported function argument type"), + other => unsupported_sql_err!("unsupported function argument type: {other}, valid function arguments for this function are: {expected_named:?}."), } } - SQLFunctionArguments { + Ok(SQLFunctionArguments { positional: positional_args, named: named_args, - } - .try_into() + }) } pub(crate) fn plan_function_arg( @@ -235,7 +331,10 @@ impl SQLPlanner { } } - fn try_unwrap_function_arg_expr(&self, expr: &FunctionArgExpr) -> SQLPlannerResult<ExprRef> { + pub(crate) fn try_unwrap_function_arg_expr( + &self, + expr: &FunctionArgExpr, + ) -> SQLPlannerResult<ExprRef> { match expr { FunctionArgExpr::Expr(expr) => self.plan_expr(expr), _ => unsupported_sql_err!("Wildcard function args not yet supported"), diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 87d658660d..7d472afa9c 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -3,9 +3,9 @@ pub mod error; pub mod functions; mod modules; mod planner; - #[cfg(feature = "python")] pub mod python; +mod table_provider; #[cfg(feature = "python")] use pyo3::prelude::*; @@ -138,13 +138,13 @@ mod tests { #[case::from("select tbl2.text from tbl2")] #[case::using("select tbl2.text from tbl2 join tbl3 using (id)")] #[case( - r#" + r" select abs(i32) as abs, ceil(i32) as ceil, floor(i32) as floor, sign(i32) as sign - from tbl1"# + from tbl1" )] #[case("select round(i32, 1) from tbl1")] #[case::groupby("select max(i32) from tbl1 group by utf8")] @@ -156,7 +156,7 @@ mod tests { #[case::globalagg("select max(i32) from tbl1")] fn test_compiles(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { let plan = planner.plan_sql(query); - assert!(plan.is_ok(), "query: {}\nerror: {:?}", query, plan); + assert!(plan.is_ok(), "query: {query}\nerror: {plan:?}"); Ok(()) } @@ -295,7 +295,7 @@ mod tests { #[case::starts_with("select starts_with(utf8, 'a') as starts_with from tbl1")] #[case::contains("select contains(utf8, 'a') as contains from tbl1")] #[case::split("select split(utf8, '.') as split from tbl1")] - #[case::replace("select replace(utf8, 'a', 'b') as replace from tbl1")] + #[case::replace("select regexp_replace(utf8, 'a', 'b') as replace from tbl1")] #[case::length("select length(utf8) as length from tbl1")] #[case::lower("select lower(utf8) as lower from tbl1")] #[case::upper("select upper(utf8) as upper from tbl1")] @@ -317,7 +317,7 @@ mod tests { // #[case::to_datetime("select to_datetime(utf8, 'YYYY-MM-DD') as to_datetime from tbl1")] fn test_compiles_funcs(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { let plan = planner.plan_sql(query); - assert!(plan.is_ok(), "query: {}\nerror: {:?}", query, plan); + assert!(plan.is_ok(), "query: {query}\nerror: {plan:?}"); Ok(()) }
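The test changes above pin down the user-visible rename of the regex helpers to `regexp_*` names (registered in the utf8 module further down in this diff). For example:

```sql
-- Previously spelled `replace(...)`; the regex-based variant is now:
SELECT regexp_replace(utf8, 'a', 'b') AS replace FROM tbl1;
-- `match` similarly becomes regexp_match (pattern shown is illustrative):
SELECT regexp_match(utf8, '^a') FROM tbl1;
```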
diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index 695d3c9c79..7e8ceb5fcb 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -16,7 +16,7 @@ pub struct SQLModuleAggs; impl SQLModule for SQLModuleAggs { fn register(parent: &mut SQLFunctions) { - use AggExpr::*; + use AggExpr::{Count, Max, Mean, Min, Sum}; // HACK TO USE AggExpr as an enum rather than a bunch of structs let nil = Arc::new(Expr::Literal(LiteralValue::Null)); parent.add_fn( @@ -27,7 +27,7 @@ impl SQLModule for SQLModuleAggs { parent.add_fn("avg", Mean(nil.clone())); parent.add_fn("mean", Mean(nil.clone())); parent.add_fn("min", Min(nil.clone())); - parent.add_fn("max", Max(nil.clone())); + parent.add_fn("max", Max(nil)); } } @@ -41,6 +41,26 @@ impl SQLFunction for AggExpr { to_expr(self, inputs.as_slice()) } } + + fn docstrings(&self, alias: &str) -> String { + match self { + Self::Count(_, _) => static_docs::COUNT_DOCSTRING.to_string(), + Self::Sum(_) => static_docs::SUM_DOCSTRING.to_string(), + Self::Mean(_) => static_docs::AVG_DOCSTRING.replace("{}", alias), + Self::Min(_) => static_docs::MIN_DOCSTRING.to_string(), + Self::Max(_) => static_docs::MAX_DOCSTRING.to_string(), + e => unimplemented!("Need to implement docstrings for {e}"), + } + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::Count(_, _) | Self::Sum(_) | Self::Mean(_) | Self::Min(_) | Self::Max(_) => { + &["input"] + } + e => unimplemented!("Need to implement arg names for {e}"), + } + } } fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> { @@ -74,7 +94,7 @@ fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> { }) } -pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { +pub fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { match expr { AggExpr::Count(_, _) => unreachable!("count should be handled by this point"), AggExpr::Sum(_) => { ensure!(args.len() == 1, "sum takes exactly one argument"); Ok(args[0].clone().sum()) } @@ -89,6 +109,10 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { + AggExpr::Stddev(_) => { + ensure!(args.len() == 1, "stddev takes exactly one argument"); + Ok(args[0].clone().stddev()) + } AggExpr::Min(_) => { ensure!(args.len() == 1, "min takes exactly one argument"); Ok(args[0].clone().min()) @@ -103,3 +127,201 @@ pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { _ => unsupported_sql_err!("map_groups"), } } + +mod static_docs { + pub(crate) const COUNT_DOCSTRING: &str = + "Counts the number of non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT count(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 2 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const SUM_DOCSTRING: &str = + "Calculates the sum of non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT sum(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 300 │ + ╰───────╯ + (Showing first 1 of 1 rows)";
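The new `Stddev` arm in `to_expr` above wires the DSL's standard-deviation aggregation into SQL. Assuming it is registered under the name `stddev` (the registration hunk is not shown here), usage would look like:

```sql
-- Standard deviation of the non-null values of a column (assumed alias `stddev`).
SELECT stddev(i32) FROM tbl1;
```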
+ pub(crate) const AVG_DOCSTRING: &str = + "Calculates the average (mean) of non-null elements in the input expression. + +.. seealso:: + This SQL Function has aliases. + + * :func:`~daft.sql._sql_funcs.mean` + * :func:`~daft.sql._sql_funcs.avg` + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT {}(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────────╮ + │ x │ + │ --- │ + │ Float64 │ + ╞═══════════╡ + │ 150.0 │ + ╰───────────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const MIN_DOCSTRING: &str = + "Finds the minimum value among non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT min(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; + + pub(crate) const MAX_DOCSTRING: &str = + "Finds the maximum value among non-null elements in the input expression. + +Example: + +.. code-block:: sql + :caption: SQL + + SELECT max(x) FROM tbl + +.. code-block:: text + :caption: Input + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 100 │ + ├╌╌╌╌╌╌╌┤ + │ 200 │ + ├╌╌╌╌╌╌╌┤ + │ null │ + ╰───────╯ + (Showing first 3 of 3 rows) + +.. code-block:: text + :caption: Output + + ╭───────╮ + │ x │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 200 │ + ╰───────╯ + (Showing first 1 of 1 rows)"; +} diff --git a/src/daft-sql/src/modules/config.rs b/src/daft-sql/src/modules/config.rs new file mode 100644 index 0000000000..9a540d3025 --- /dev/null +++ b/src/daft-sql/src/modules/config.rs @@ -0,0 +1,391 @@ +use common_io_config::{AzureConfig, GCSConfig, HTTPConfig, IOConfig, S3Config}; +use daft_core::prelude::{DataType, Field}; +use daft_dsl::{literal_value, Expr, ExprRef, LiteralValue}; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + unsupported_sql_err, +}; + +pub struct SQLModuleConfig; + +impl SQLModule for SQLModuleConfig { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("S3Config", S3ConfigFunction); + parent.add_fn("HTTPConfig", HTTPConfigFunction); + parent.add_fn("AzureConfig", AzureConfigFunction); + parent.add_fn("GCSConfig", GCSConfigFunction); + } +}
+pub struct S3ConfigFunction; +macro_rules! item { + ($name:expr, $ty:ident) => { + ( + Field::new(stringify!($name), DataType::$ty), + literal_value($name), + ) + }; +} + +impl SQLFunction for S3ConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult<ExprRef> { + // TODO(cory): Ideally we should use serde to deserialize the input arguments + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &[ + "region_name", + "endpoint_url", + "key_id", + "session_token", + "access_key", + "credentials_provider", + "buffer_time", + "max_connections_per_io_thread", + "retry_initial_backoff_ms", + "connect_timeout_ms", + "read_timeout_ms", + "num_tries", + "retry_mode", + "anonymous", + "use_ssl", + "verify_ssl", + "check_hostname_ssl", + "requester_pays", + "force_virtual_addressing", + "profile_name", + ], + 0, + )?; + + let region_name = args.try_get_named::<String>("region_name")?; + let endpoint_url = args.try_get_named::<String>("endpoint_url")?; + let key_id = args.try_get_named::<String>("key_id")?; + let session_token = args.try_get_named::<String>("session_token")?; + + let access_key = args.try_get_named::<String>("access_key")?; + let buffer_time = args.try_get_named("buffer_time")?.map(|t: i64| t as u64); + + let max_connections_per_io_thread = args + .try_get_named("max_connections_per_io_thread")? + .map(|t: i64| t as u32); + + let retry_initial_backoff_ms = args + .try_get_named("retry_initial_backoff_ms")? + .map(|t: i64| t as u64); + + let connect_timeout_ms = args + .try_get_named("connect_timeout_ms")? + .map(|t: i64| t as u64); + + let read_timeout_ms = args + .try_get_named("read_timeout_ms")? + .map(|t: i64| t as u64); + + let num_tries = args.try_get_named("num_tries")?.map(|t: i64| t as u32); + let retry_mode = args.try_get_named::<String>("retry_mode")?; + let anonymous = args.try_get_named::<bool>("anonymous")?; + let use_ssl = args.try_get_named::<bool>("use_ssl")?; + let verify_ssl = args.try_get_named::<bool>("verify_ssl")?; + let check_hostname_ssl = args.try_get_named::<bool>("check_hostname_ssl")?; + let requester_pays = args.try_get_named::<bool>("requester_pays")?; + let force_virtual_addressing = args.try_get_named::<bool>("force_virtual_addressing")?; + let profile_name = args.try_get_named::<String>("profile_name")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("s3")), + item!(region_name, Utf8), + item!(endpoint_url, Utf8), + item!(key_id, Utf8), + item!(session_token, Utf8), + item!(access_key, Utf8), + item!(buffer_time, UInt64), + item!(max_connections_per_io_thread, UInt32), + item!(retry_initial_backoff_ms, UInt64), + item!(connect_timeout_ms, UInt64), + item!(read_timeout_ms, UInt64), + item!(num_tries, UInt32), + item!(retry_mode, Utf8), + item!(anonymous, Boolean), + item!(use_ssl, Boolean), + item!(verify_ssl, Boolean), + item!(check_hostname_ssl, Boolean), + item!(requester_pays, Boolean), + item!(force_virtual_addressing, Boolean), + item!(profile_name, Utf8), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +}
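`S3ConfigFunction` evaluates to a struct literal tagged with `variant = 's3'`, which `expr_to_iocfg` (later in this file) folds back into an `IOConfig`. A hypothetical end-to-end use with a `read_parquet` table function; the table function name and the `:=` named-argument operator are assumptions, only the config constructor itself is defined here:

```sql
SELECT *
FROM read_parquet(
    's3://bucket/path/file.parquet',
    io_config := S3Config(region_name := 'us-west-2', anonymous := true)
);
```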
(Field::new("variant", DataType::Utf8), literal_value("http")), + item!(user_agent, Utf8), + item!(bearer_token, Utf8), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} +pub struct AzureConfigFunction; +impl SQLFunction for AzureConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &[ + "storage_account", + "access_key", + "sas_token", + "bearer_token", + "tenant_id", + "client_id", + "client_secret", + "use_fabric_endpoint", + "anonymous", + "endpoint_url", + "use_ssl", + ], + 0, + )?; + + let storage_account = args.try_get_named::("storage_account")?; + let access_key = args.try_get_named::("access_key")?; + let sas_token = args.try_get_named::("sas_token")?; + let bearer_token = args.try_get_named::("bearer_token")?; + let tenant_id = args.try_get_named::("tenant_id")?; + let client_id = args.try_get_named::("client_id")?; + let client_secret = args.try_get_named::("client_secret")?; + let use_fabric_endpoint = args.try_get_named::("use_fabric_endpoint")?; + let anonymous = args.try_get_named::("anonymous")?; + let endpoint_url = args.try_get_named::("endpoint_url")?; + let use_ssl = args.try_get_named::("use_ssl")?; + + let entries = vec![ + ( + Field::new("variant", DataType::Utf8), + literal_value("azure"), + ), + item!(storage_account, Utf8), + item!(access_key, Utf8), + item!(sas_token, Utf8), + item!(bearer_token, Utf8), + item!(tenant_id, Utf8), + item!(client_id, Utf8), + item!(client_secret, Utf8), + item!(use_fabric_endpoint, Boolean), + item!(anonymous, Boolean), + item!(endpoint_url, Utf8), + item!(use_ssl, Boolean), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub struct GCSConfigFunction; + +impl SQLFunction for GCSConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &["project_id", "credentials", "token", "anonymous"], + 0, + )?; + + let project_id = args.try_get_named::("project_id")?; + let credentials = args.try_get_named::("credentials")?; + let token = args.try_get_named::("token")?; + let anonymous = args.try_get_named::("anonymous")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("gcs")), + item!(project_id, Utf8), + item!(credentials, Utf8), + item!(token, Utf8), + item!(anonymous, Boolean), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult { + // TODO(CORY): use serde to deserialize this + let Expr::Literal(LiteralValue::Struct(entries)) = expr.as_ref() else { + unsupported_sql_err!("Invalid IOConfig"); + }; + + macro_rules! get_value { + ($field:literal, $type:ident) => { + entries + .get(&Field::new($field, DataType::$type)) + .and_then(|s| match s { + LiteralValue::$type(s) => Some(Ok(s.clone())), + LiteralValue::Null => None, + _ => Some(Err(PlannerError::invalid_argument($field, "IOConfig"))), + }) + .transpose() + }; + } + + let variant = get_value!("variant", Utf8)? 
+ .expect("variant is required for IOConfig, this indicates a programming error"); + + match variant.as_ref() { + "s3" => { + let region_name = get_value!("region_name", Utf8)?; + let endpoint_url = get_value!("endpoint_url", Utf8)?; + let key_id = get_value!("key_id", Utf8)?; + let session_token = get_value!("session_token", Utf8)?.map(|s| s.into()); + let access_key = get_value!("access_key", Utf8)?.map(|s| s.into()); + let buffer_time = get_value!("buffer_time", UInt64)?; + let max_connections_per_io_thread = + get_value!("max_connections_per_io_thread", UInt32)?; + let retry_initial_backoff_ms = get_value!("retry_initial_backoff_ms", UInt64)?; + let connect_timeout_ms = get_value!("connect_timeout_ms", UInt64)?; + let read_timeout_ms = get_value!("read_timeout_ms", UInt64)?; + let num_tries = get_value!("num_tries", UInt32)?; + let retry_mode = get_value!("retry_mode", Utf8)?; + let anonymous = get_value!("anonymous", Boolean)?; + let use_ssl = get_value!("use_ssl", Boolean)?; + let verify_ssl = get_value!("verify_ssl", Boolean)?; + let check_hostname_ssl = get_value!("check_hostname_ssl", Boolean)?; + let requester_pays = get_value!("requester_pays", Boolean)?; + let force_virtual_addressing = get_value!("force_virtual_addressing", Boolean)?; + let profile_name = get_value!("profile_name", Utf8)?; + let default = S3Config::default(); + let s3_config = S3Config { + region_name, + endpoint_url, + key_id, + session_token, + access_key, + credentials_provider: None, + buffer_time, + max_connections_per_io_thread: max_connections_per_io_thread + .unwrap_or(default.max_connections_per_io_thread), + retry_initial_backoff_ms: retry_initial_backoff_ms + .unwrap_or(default.retry_initial_backoff_ms), + connect_timeout_ms: connect_timeout_ms.unwrap_or(default.connect_timeout_ms), + read_timeout_ms: read_timeout_ms.unwrap_or(default.read_timeout_ms), + num_tries: num_tries.unwrap_or(default.num_tries), + retry_mode, + anonymous: anonymous.unwrap_or(default.anonymous), + use_ssl: use_ssl.unwrap_or(default.use_ssl), + verify_ssl: verify_ssl.unwrap_or(default.verify_ssl), + check_hostname_ssl: check_hostname_ssl.unwrap_or(default.check_hostname_ssl), + requester_pays: requester_pays.unwrap_or(default.requester_pays), + force_virtual_addressing: force_virtual_addressing + .unwrap_or(default.force_virtual_addressing), + profile_name, + }; + + Ok(IOConfig { + s3: s3_config, + ..Default::default() + }) + } + "http" => { + let default = HTTPConfig::default(); + let user_agent = get_value!("user_agent", Utf8)?.unwrap_or(default.user_agent); + let bearer_token = get_value!("bearer_token", Utf8)?.map(|s| s.into()); + + Ok(IOConfig { + http: HTTPConfig { + user_agent, + bearer_token, + }, + ..Default::default() + }) + } + "azure" => { + let storage_account = get_value!("storage_account", Utf8)?; + let access_key = get_value!("access_key", Utf8)?; + let sas_token = get_value!("sas_token", Utf8)?; + let bearer_token = get_value!("bearer_token", Utf8)?; + let tenant_id = get_value!("tenant_id", Utf8)?; + let client_id = get_value!("client_id", Utf8)?; + let client_secret = get_value!("client_secret", Utf8)?; + let use_fabric_endpoint = get_value!("use_fabric_endpoint", Boolean)?; + let anonymous = get_value!("anonymous", Boolean)?; + let endpoint_url = get_value!("endpoint_url", Utf8)?; + let use_ssl = get_value!("use_ssl", Boolean)?; + + let default = AzureConfig::default(); + + Ok(IOConfig { + azure: AzureConfig { + storage_account, + access_key: access_key.map(|s| s.into()), + sas_token, + bearer_token, + 
tenant_id, + client_id, + client_secret: client_secret.map(|s| s.into()), + use_fabric_endpoint: use_fabric_endpoint.unwrap_or(default.use_fabric_endpoint), + anonymous: anonymous.unwrap_or(default.anonymous), + endpoint_url, + use_ssl: use_ssl.unwrap_or(default.use_ssl), + }, + ..Default::default() + }) + } + "gcs" => { + let project_id = get_value!("project_id", Utf8)?; + let credentials = get_value!("credentials", Utf8)?; + let token = get_value!("token", Utf8)?; + let anonymous = get_value!("anonymous", Boolean)?; + let default = GCSConfig::default(); + + Ok(IOConfig { + gcs: GCSConfig { + project_id, + credentials: credentials.map(|s| s.into()), + token, + anonymous: anonymous.unwrap_or(default.anonymous), + }, + ..Default::default() + }) + } + _ => { + unreachable!("variant is required for IOConfig, this indicates a programming error") + } + } +} diff --git a/src/daft-sql/src/modules/float.rs b/src/daft-sql/src/modules/float.rs index 4cfffe34b4..292a5c4d85 100644 --- a/src/daft-sql/src/modules/float.rs +++ b/src/daft-sql/src/modules/float.rs @@ -37,6 +37,14 @@ impl SQLFunction for SQLFillNan { _ => unsupported_sql_err!("Invalid arguments for 'fill_nan': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::FILL_NAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "fill_value"] + } } pub struct SQLIsInf {} @@ -52,6 +60,14 @@ impl SQLFunction for SQLIsInf { _ => unsupported_sql_err!("Invalid arguments for 'is_inf': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::IS_INF_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLIsNan {} @@ -67,6 +83,14 @@ impl SQLFunction for SQLIsNan { _ => unsupported_sql_err!("Invalid arguments for 'is_nan': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::IS_NAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLNotNan {} @@ -82,4 +106,26 @@ impl SQLFunction for SQLNotNan { _ => unsupported_sql_err!("Invalid arguments for 'not_nan': '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::NOT_NAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } +} + +mod static_docs { + pub(crate) const FILL_NAN_DOCSTRING: &str = + "Replaces NaN values in the input expression with a specified fill value."; + + pub(crate) const IS_INF_DOCSTRING: &str = + "Checks if the input expression is infinite (positive or negative infinity)."; + + pub(crate) const IS_NAN_DOCSTRING: &str = + "Checks if the input expression is NaN (Not a Number)."; + + pub(crate) const NOT_NAN_DOCSTRING: &str = + "Checks if the input expression is not NaN (Not a Number)."; } diff --git a/src/daft-sql/src/modules/hashing.rs b/src/daft-sql/src/modules/hashing.rs new file mode 100644 index 0000000000..e1ca169135 --- /dev/null +++ b/src/daft-sql/src/modules/hashing.rs @@ -0,0 +1,111 @@ +use daft_dsl::ExprRef; +use daft_functions::{ + hash::hash, + minhash::{minhash, MinHashFunction}, +}; +use sqlparser::ast::FunctionArg; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + unsupported_sql_err, +}; + +pub struct SQLModuleHashing; + +impl SQLModule for SQLModuleHashing { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("hash", SQLHash); + parent.add_fn("minhash", 
SQLMinhash); + } +} + +pub struct SQLHash; + +impl SQLFunction for SQLHash { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(hash(input, None)) + } + [input, seed] => { + let input = planner.plan_function_arg(input)?; + match seed { + FunctionArg::Named { name, arg, .. } if name.value == "seed" => { + let seed = planner.try_unwrap_function_arg_expr(arg)?; + Ok(hash(input, Some(seed))) + } + arg @ FunctionArg::Unnamed(_) => { + let seed = planner.plan_function_arg(arg)?; + Ok(hash(input, Some(seed))) + } + _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"), + } + } + _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"), + } + } +} + +pub struct SQLMinhash; + +impl TryFrom for MinHashFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let num_hashes = args + .get_named("num_hashes") + .ok_or_else(|| PlannerError::invalid_operation("num_hashes is required"))? + .as_literal() + .and_then(daft_dsl::LiteralValue::as_i64) + .ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))? + as usize; + + let ngram_size = args + .get_named("ngram_size") + .ok_or_else(|| PlannerError::invalid_operation("ngram_size is required"))? + .as_literal() + .and_then(daft_dsl::LiteralValue::as_i64) + .ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))? + as usize; + let seed = args + .get_named("seed") + .map(|arg| { + arg.as_literal() + .and_then(daft_dsl::LiteralValue::as_i64) + .ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer")) + }) + .transpose()? + .unwrap_or(1) as u32; + Ok(Self { + num_hashes, + ngram_size, + seed, + }) + } +} + +impl SQLFunction for SQLMinhash { + fn to_expr( + &self, + inputs: &[FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: MinHashFunction = + planner.plan_function_args(args, &["num_hashes", "ngram_size", "seed"], 0)?; + + Ok(minhash(input, args.num_hashes, args.ngram_size, args.seed)) + } + _ => unsupported_sql_err!("Invalid arguments for minhash: '{inputs:?}'"), + } + } +} diff --git a/src/daft-sql/src/modules/image/crop.rs b/src/daft-sql/src/modules/image/crop.rs index 36c72fcca3..286208889c 100644 --- a/src/daft-sql/src/modules/image/crop.rs +++ b/src/daft-sql/src/modules/image/crop.rs @@ -21,4 +21,12 @@ impl SQLFunction for SQLImageCrop { _ => unsupported_sql_err!("Invalid arguments for image_crop: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Crops an image to a specified bounding box. The bounding box is specified as [x, y, width, height].".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "bounding_box"] + } } diff --git a/src/daft-sql/src/modules/image/decode.rs b/src/daft-sql/src/modules/image/decode.rs index a6b95d538d..a896c67a05 100644 --- a/src/daft-sql/src/modules/image/decode.rs +++ b/src/daft-sql/src/modules/image/decode.rs @@ -61,4 +61,12 @@ impl SQLFunction for SQLImageDecode { _ => unsupported_sql_err!("Invalid arguments for image_decode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Decodes an image from binary data. 
diff --git a/src/daft-sql/src/modules/image/crop.rs b/src/daft-sql/src/modules/image/crop.rs index 36c72fcca3..286208889c 100644 --- a/src/daft-sql/src/modules/image/crop.rs +++ b/src/daft-sql/src/modules/image/crop.rs @@ -21,4 +21,12 @@ impl SQLFunction for SQLImageCrop { _ => unsupported_sql_err!("Invalid arguments for image_crop: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Crops an image to a specified bounding box. The bounding box is specified as [x, y, width, height].".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "bounding_box"] + } } diff --git a/src/daft-sql/src/modules/image/decode.rs b/src/daft-sql/src/modules/image/decode.rs index a6b95d538d..a896c67a05 100644 --- a/src/daft-sql/src/modules/image/decode.rs +++ b/src/daft-sql/src/modules/image/decode.rs @@ -61,4 +61,12 @@ impl SQLFunction for SQLImageDecode { _ => unsupported_sql_err!("Invalid arguments for image_decode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Decodes an image from binary data. Optionally, you can specify the image mode and error handling behavior.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "mode", "on_error"] + } } diff --git a/src/daft-sql/src/modules/image/encode.rs b/src/daft-sql/src/modules/image/encode.rs index a902179f88..acf489c807 100644 --- a/src/daft-sql/src/modules/image/encode.rs +++ b/src/daft-sql/src/modules/image/encode.rs @@ -46,4 +46,12 @@ impl SQLFunction for SQLImageEncode { _ => unsupported_sql_err!("Invalid arguments for image_encode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Encodes an image into the specified image file format, returning a binary column of encoded bytes.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "image_format"] + } } diff --git a/src/daft-sql/src/modules/image/resize.rs b/src/daft-sql/src/modules/image/resize.rs index 8ce37eb7f8..ac6c12fd50 100644 --- a/src/daft-sql/src/modules/image/resize.rs +++ b/src/daft-sql/src/modules/image/resize.rs @@ -16,7 +16,7 @@ impl TryFrom<SQLFunctionArguments> for ImageResize { fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { let width = args .get_named("w") - .or_else(|| args.get_unnamed(0)) + .or_else(|| args.get_positional(0)) .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Int64(i)) => Ok(*i), _ => unsupported_sql_err!("Expected width to be a number"), @@ -28,7 +28,7 @@ impl TryFrom<SQLFunctionArguments> for ImageResize { let height = args .get_named("h") - .or_else(|| args.get_unnamed(1)) + .or_else(|| args.get_positional(1)) .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Int64(i)) => Ok(*i), _ => unsupported_sql_err!("Expected height to be a number"), @@ -64,4 +64,12 @@ impl SQLFunction for SQLImageResize { _ => unsupported_sql_err!("Invalid arguments for image_resize: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Resizes an image to the specified width and height.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "width", "height"] + } }
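`ImageResize` accepts width and height positionally or under the names `w`/`h` (via `get_named`/`get_positional` above). Illustrative usage against a hypothetical `images` table with an `image_col` column; the `:=` named-argument operator is an assumption:

```sql
SELECT image_resize(image_col, 128, 128) FROM images;
SELECT image_resize(image_col, w := 128, h := 128) FROM images;
```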
diff --git a/src/daft-sql/src/modules/image/to_mode.rs b/src/daft-sql/src/modules/image/to_mode.rs index a02efb2d36..b5b9202d1f 100644 --- a/src/daft-sql/src/modules/image/to_mode.rs +++ b/src/daft-sql/src/modules/image/to_mode.rs @@ -41,4 +41,12 @@ impl SQLFunction for SQLImageToMode { _ => unsupported_sql_err!("Invalid arguments for image_encode: '{inputs:?}'"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Converts an image to the specified mode (e.g. RGB, RGBA, Grayscale).".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input_image", "mode"] + } } diff --git a/src/daft-sql/src/modules/json.rs b/src/daft-sql/src/modules/json.rs index f0d600daea..8dc9e617f5 100644 --- a/src/daft-sql/src/modules/json.rs +++ b/src/daft-sql/src/modules/json.rs @@ -35,4 +35,17 @@ impl SQLFunction for JsonQuery { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::JSON_QUERY_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "query"] + } +} + +mod static_docs { + pub(crate) const JSON_QUERY_DOCSTRING: &str = + "Extracts a JSON object from a JSON string using a JSONPath expression."; } diff --git a/src/daft-sql/src/modules/list.rs b/src/daft-sql/src/modules/list.rs index b9e52d9748..bd6db25990 100644 --- a/src/daft-sql/src/modules/list.rs +++ b/src/daft-sql/src/modules/list.rs @@ -55,6 +55,14 @@ impl SQLFunction for SQLListChunk { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_CHUNK_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "chunk_size"] + } } pub struct SQLListCount; @@ -86,6 +94,14 @@ impl SQLFunction for SQLListCount { _ => unsupported_sql_err!("invalid arguments for list_count. Expected either list_count(expr) or list_count(expr, mode)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_COUNT_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "mode"] + } } pub struct SQLExplode; @@ -104,6 +120,14 @@ impl SQLFunction for SQLExplode { _ => unsupported_sql_err!("Expected 1 argument"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::EXPLODE_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListJoin; @@ -125,6 +149,14 @@ impl SQLFunction for SQLListJoin { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_JOIN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "separator"] + } } pub struct SQLListMax; @@ -143,6 +175,14 @@ impl SQLFunction for SQLListMax { _ => unsupported_sql_err!("invalid arguments for list_max. Expected list_max(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_MAX_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListMean; @@ -161,6 +201,14 @@ impl SQLFunction for SQLListMean { _ => unsupported_sql_err!("invalid arguments for list_mean. Expected list_mean(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_MEAN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListMin; @@ -179,6 +227,14 @@ impl SQLFunction for SQLListMin { _ => unsupported_sql_err!("invalid arguments for list_min. Expected list_min(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_MIN_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListSum; @@ -197,6 +253,14 @@ impl SQLFunction for SQLListSum { _ => unsupported_sql_err!("invalid arguments for list_sum.
Expected list_sum(expr)"), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_SUM_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } pub struct SQLListSlice; @@ -219,6 +283,14 @@ impl SQLFunction for SQLListSlice { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_SLICE_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "start", "end"] + } } pub struct SQLListSort; @@ -258,4 +330,38 @@ impl SQLFunction for SQLListSort { ), } } + + fn docstrings(&self, _alias: &str) -> String { + static_docs::LIST_SORT_DOCSTRING.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "order"] + } +} + +mod static_docs { + pub(crate) const LIST_CHUNK_DOCSTRING: &str = "Splits a list into chunks of a specified size."; + + pub(crate) const LIST_COUNT_DOCSTRING: &str = "Counts the number of elements in a list."; + + pub(crate) const EXPLODE_DOCSTRING: &str = "Expands a list column into multiple rows."; + + pub(crate) const LIST_JOIN_DOCSTRING: &str = + "Joins elements of a list into a single string using a specified separator."; + + pub(crate) const LIST_MAX_DOCSTRING: &str = "Returns the maximum value in a list."; + + pub(crate) const LIST_MEAN_DOCSTRING: &str = + "Calculates the mean (average) of values in a list."; + + pub(crate) const LIST_MIN_DOCSTRING: &str = "Returns the minimum value in a list."; + + pub(crate) const LIST_SUM_DOCSTRING: &str = "Calculates the sum of values in a list."; + + pub(crate) const LIST_SLICE_DOCSTRING: &str = + "Extracts a portion of a list from a start index to an end index."; + + pub(crate) const LIST_SORT_DOCSTRING: &str = + "Sorts the elements of a list in ascending or descending order."; } diff --git a/src/daft-sql/src/modules/map.rs b/src/daft-sql/src/modules/map.rs index d3a328f3a4..0ae5aca2be 100644 --- a/src/daft-sql/src/modules/map.rs +++ b/src/daft-sql/src/modules/map.rs @@ -30,4 +30,23 @@ impl SQLFunction for MapGet { _ => invalid_operation_err!("Expected 2 input args"), } } + + fn docstrings(&self, alias: &str) -> String { + static_docs::MAP_GET_DOCSTRING.replace("{}", alias) + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "key"] + } +} + +mod static_docs { + pub(crate) const MAP_GET_DOCSTRING: &str = + "Retrieves the value associated with a given key from a map. + +.. 
seealso:: + + * :func:`~daft.sql._sql_funcs.map_get` + * :func:`~daft.sql._sql_funcs.map_extract` +"; } diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index 0f60ecbff9..ded8007e2d 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -1,7 +1,9 @@ use crate::functions::SQLFunctions; pub mod aggs; +pub mod config; pub mod float; +pub mod hashing; pub mod image; pub mod json; pub mod list; @@ -15,6 +17,7 @@ pub mod temporal; pub mod utf8; pub use aggs::SQLModuleAggs; +pub use config::SQLModuleConfig; pub use float::SQLModuleFloat; pub use image::SQLModuleImage; pub use json::SQLModuleJson; diff --git a/src/daft-sql/src/modules/numeric.rs b/src/daft-sql/src/modules/numeric.rs index 197d958860..66178f2f3b 100644 --- a/src/daft-sql/src/modules/numeric.rs +++ b/src/daft-sql/src/modules/numeric.rs @@ -88,6 +88,67 @@ impl SQLFunction for SQLNumericExpr { let inputs = self.args_to_expr_unnamed(inputs, planner)?; to_expr(self, inputs.as_slice()) } + + fn docstrings(&self, _alias: &str) -> String { + let docstring = match self { + Self::Abs => "Gets the absolute value of a number.", + Self::Ceil => "Rounds a number up to the nearest integer.", + Self::Exp => "Calculates the exponential of a number (e^x).", + Self::Floor => "Rounds a number down to the nearest integer.", + Self::Round => "Rounds a number to a specified number of decimal places.", + Self::Sign => "Returns the sign of a number (-1, 0, or 1).", + Self::Sqrt => "Calculates the square root of a number.", + Self::Sin => "Calculates the sine of an angle in radians.", + Self::Cos => "Calculates the cosine of an angle in radians.", + Self::Tan => "Calculates the tangent of an angle in radians.", + Self::Cot => "Calculates the cotangent of an angle in radians.", + Self::ArcSin => "Calculates the inverse sine (arc sine) of a number.", + Self::ArcCos => "Calculates the inverse cosine (arc cosine) of a number.", + Self::ArcTan => "Calculates the inverse tangent (arc tangent) of a number.", + Self::ArcTan2 => { + "Calculates the angle between the positive x-axis and the ray from (0,0) to (x,y)." 
+ } + Self::Radians => "Converts an angle from degrees to radians.", + Self::Degrees => "Converts an angle from radians to degrees.", + Self::Log => "Calculates the natural logarithm of a number.", + Self::Log2 => "Calculates the base-2 logarithm of a number.", + Self::Log10 => "Calculates the base-10 logarithm of a number.", + Self::Ln => "Calculates the natural logarithm of a number.", + Self::ArcTanh => "Calculates the inverse hyperbolic tangent of a number.", + Self::ArcCosh => "Calculates the inverse hyperbolic cosine of a number.", + Self::ArcSinh => "Calculates the inverse hyperbolic sine of a number.", + }; + docstring.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::Abs + | Self::Ceil + | Self::Floor + | Self::Sign + | Self::Sqrt + | Self::Sin + | Self::Cos + | Self::Tan + | Self::Cot + | Self::ArcSin + | Self::ArcCos + | Self::ArcTan + | Self::Radians + | Self::Degrees + | Self::Log2 + | Self::Log10 + | Self::Ln + | Self::ArcTanh + | Self::ArcCosh + | Self::ArcSinh => &["input"], + Self::Log => &["input", "base"], + Self::Round => &["input", "precision"], + Self::Exp => &["input", "exponent"], + Self::ArcTan2 => &["y", "x"], + } + } } fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult { @@ -180,8 +241,8 @@ fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult .as_literal() .and_then(|lit| match lit { LiteralValue::Float64(f) => Some(*f), - LiteralValue::Int32(i) => Some(*i as f64), - LiteralValue::UInt32(u) => Some(*u as f64), + LiteralValue::Int32(i) => Some(f64::from(*i)), + LiteralValue::UInt32(u) => Some(f64::from(*u)), LiteralValue::Int64(i) => Some(*i as f64), LiteralValue::UInt64(u) => Some(*u as f64), _ => None, diff --git a/src/daft-sql/src/modules/partitioning.rs b/src/daft-sql/src/modules/partitioning.rs index e833edd51d..e3600e6af3 100644 --- a/src/daft-sql/src/modules/partitioning.rs +++ b/src/daft-sql/src/modules/partitioning.rs @@ -42,14 +42,14 @@ impl SQLFunction for PartitioningExpr { let n = planner .plan_function_arg(&args[1])? .as_literal() - .and_then(|l| l.as_i64()) + .and_then(daft_dsl::LiteralValue::as_i64) .ok_or_else(|| { crate::error::PlannerError::unsupported_sql( "Expected integer literal".to_string(), ) }) .and_then(|n| { - if n > i32::MAX as i64 { + if n > i64::from(i32::MAX) { Err(crate::error::PlannerError::unsupported_sql( "Integer literal too large".to_string(), )) @@ -69,7 +69,7 @@ impl SQLFunction for PartitioningExpr { let w = planner .plan_function_arg(&args[1])? 
.as_literal() - .and_then(|l| l.as_i64()) + .and_then(daft_dsl::LiteralValue::as_i64) .ok_or_else(|| { crate::error::PlannerError::unsupported_sql( "Expected integer literal".to_string(), @@ -80,6 +80,28 @@ impl SQLFunction for PartitioningExpr { } } } + + fn docstrings(&self, _alias: &str) -> String { + match self { + Self::Years => "Extracts the number of years since epoch time from a datetime expression.".to_string(), + Self::Months => "Extracts the number of months since epoch time from a datetime expression.".to_string(), + Self::Days => "Extracts the number of days since epoch time from a datetime expression.".to_string(), + Self::Hours => "Extracts the number of hours since epoch time from a datetime expression.".to_string(), + Self::IcebergBucket(_) => "Computes a bucket number for the input expression based on the specified number of buckets using an Iceberg-specific hash.".to_string(), + Self::IcebergTruncate(_) => "Truncates the input expression to a specified width.".to_string(), + } + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::Years => &["input"], + Self::Months => &["input"], + Self::Days => &["input"], + Self::Hours => &["input"], + Self::IcebergBucket(_) => &["input", "num_buckets"], + Self::IcebergTruncate(_) => &["input", "width"], + } + } } fn partitioning_helper<F: Fn(ExprRef) -> daft_dsl::ExprRef>( diff --git a/src/daft-sql/src/modules/structs.rs b/src/daft-sql/src/modules/structs.rs index 66be42d8e3..17fae85c9e 100644 --- a/src/daft-sql/src/modules/structs.rs +++ b/src/daft-sql/src/modules/structs.rs @@ -34,4 +34,12 @@ impl SQLFunction for StructGet { _ => invalid_operation_err!("Expected 2 input args"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Extracts a field from a struct expression by name.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input", "field"] + } } diff --git a/src/daft-sql/src/modules/temporal.rs b/src/daft-sql/src/modules/temporal.rs index 58687724fa..51e275d7c1 100644 --- a/src/daft-sql/src/modules/temporal.rs +++ b/src/daft-sql/src/modules/temporal.rs @@ -1,5 +1,7 @@ use daft_dsl::ExprRef; -use daft_functions::temporal::*; +use daft_functions::temporal::{ + dt_date, dt_day, dt_day_of_week, dt_hour, dt_minute, dt_month, dt_second, dt_time, dt_year, +}; use sqlparser::ast::FunctionArg; use super::SQLModule; @@ -50,6 +52,16 @@ macro_rules!
temporal { ), } } + fn docstrings(&self, _alias: &str) -> String { + format!( + "Extracts the {} component from a datetime expression.", + stringify!($fn_name).replace("dt_", "") + ) + } + + fn arg_names(&self) -> &'static [&'static str] { + &["input"] + } } }; } diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index 263a8bd9e7..2aafed49c1 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -1,32 +1,48 @@ +use daft_core::array::ops::Utf8NormalizeOptions; use daft_dsl::{ - functions::{self, utf8::Utf8Expr}, + functions::{ + self, + utf8::{normalize, Utf8Expr}, + }, ExprRef, LiteralValue, }; +use daft_functions::{ + count_matches::{utf8_count_matches, CountMatchesFunction}, + tokenize::{tokenize_decode, tokenize_encode, TokenizeDecodeFunction, TokenizeEncodeFunction}, +}; use super::SQLModule; use crate::{ - ensure, error::SQLPlannerResult, functions::SQLFunction, invalid_operation_err, - unsupported_sql_err, + ensure, + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments}, + invalid_operation_err, unsupported_sql_err, }; pub struct SQLModuleUtf8; impl SQLModule for SQLModuleUtf8 { fn register(parent: &mut crate::functions::SQLFunctions) { - use Utf8Expr::*; + use Utf8Expr::{ + Capitalize, Contains, EndsWith, Extract, ExtractAll, Find, Left, Length, LengthBytes, + Lower, Lpad, Lstrip, Match, Repeat, Replace, Reverse, Right, Rpad, Rstrip, Split, + StartsWith, ToDate, ToDatetime, Upper, + }; parent.add_fn("ends_with", EndsWith); parent.add_fn("starts_with", StartsWith); parent.add_fn("contains", Contains); - parent.add_fn("split", Split(true)); + parent.add_fn("split", Split(false)); // TODO add split variants // parent.add("split", f(Split(false))); - parent.add_fn("match", Match); - parent.add_fn("extract", Extract(0)); - parent.add_fn("extract_all", ExtractAll(0)); - parent.add_fn("replace", Replace(true)); + parent.add_fn("regexp_match", Match); + parent.add_fn("regexp_extract", Extract(0)); + parent.add_fn("regexp_extract_all", ExtractAll(0)); + parent.add_fn("regexp_replace", Replace(true)); + parent.add_fn("regexp_split", Split(true)); // TODO add replace variants // parent.add("replace", f(Replace(false))); parent.add_fn("length", Length); + parent.add_fn("length_bytes", LengthBytes); parent.add_fn("lower", Lower); parent.add_fn("upper", Upper); parent.add_fn("lstrip", Lstrip); @@ -39,13 +55,13 @@ impl SQLModule for SQLModuleUtf8 { parent.add_fn("rpad", Rpad); parent.add_fn("lpad", Lpad); parent.add_fn("repeat", Repeat); - parent.add_fn("like", Like); - parent.add_fn("ilike", Ilike); - parent.add_fn("substr", Substr); - parent.add_fn("to_date", ToDate("".to_string())); - parent.add_fn("to_datetime", ToDatetime("".to_string(), None)); - // TODO add normalization variants. 
- // parent.add("normalize", f(Normalize(Default::default()))); + + parent.add_fn("to_date", ToDate(String::new())); + parent.add_fn("to_datetime", ToDatetime(String::new(), None)); + parent.add_fn("count_matches", SQLCountMatches); + parent.add_fn("normalize", SQLNormalize); + parent.add_fn("tokenize_encode", SQLTokenizeEncode); + parent.add_fn("tokenize_decode", SQLTokenizeDecode); } } @@ -60,11 +76,85 @@ impl SQLFunction for Utf8Expr { let inputs = self.args_to_expr_unnamed(inputs, planner)?; to_expr(self, &inputs) } + + fn docstrings(&self, _alias: &str) -> String { + match self { + Self::EndsWith => "Returns true if the string ends with the specified substring".to_string(), + Self::StartsWith => "Returns true if the string starts with the specified substring".to_string(), + Self::Contains => "Returns true if the string contains the specified substring".to_string(), + Self::Split(_) => "Splits the string by the specified delimiter and returns an array of substrings".to_string(), + Self::Match => "Returns true if the string matches the specified regular expression pattern".to_string(), + Self::Extract(_) => "Extracts the first substring that matches the specified regular expression pattern".to_string(), + Self::ExtractAll(_) => "Extracts all substrings that match the specified regular expression pattern".to_string(), + Self::Replace(_) => "Replaces all occurrences of a substring with a new string".to_string(), + Self::Like => "Returns true if the string matches the specified SQL LIKE pattern".to_string(), + Self::Ilike => "Returns true if the string matches the specified SQL LIKE pattern (case-insensitive)".to_string(), + Self::Length => "Returns the length of the string".to_string(), + Self::Lower => "Converts the string to lowercase".to_string(), + Self::Upper => "Converts the string to uppercase".to_string(), + Self::Lstrip => "Removes leading whitespace from the string".to_string(), + Self::Rstrip => "Removes trailing whitespace from the string".to_string(), + Self::Reverse => "Reverses the order of characters in the string".to_string(), + Self::Capitalize => "Capitalizes the first character of the string".to_string(), + Self::Left => "Returns the specified number of leftmost characters from the string".to_string(), + Self::Right => "Returns the specified number of rightmost characters from the string".to_string(), + Self::Find => "Returns the index of the first occurrence of a substring within the string".to_string(), + Self::Rpad => "Pads the string on the right side with the specified string until it reaches the specified length".to_string(), + Self::Lpad => "Pads the string on the left side with the specified string until it reaches the specified length".to_string(), + Self::Repeat => "Repeats the string the specified number of times".to_string(), + Self::Substr => "Returns a substring of the string starting at the specified position and length".to_string(), + Self::ToDate(_) => "Parses the string as a date using the specified format.".to_string(), + Self::ToDatetime(_, _) => "Parses the string as a datetime using the specified format.".to_string(), + Self::LengthBytes => "Returns the length of the string in bytes".to_string(), + Self::Normalize(_) => unimplemented!("Normalize not implemented"), + } + } + + fn arg_names(&self) -> &'static [&'static str] { + match self { + Self::EndsWith => &["string_input", "substring"], + Self::StartsWith => &["string_input", "substring"], + Self::Contains => &["string_input", "substring"], + Self::Split(_) => &["string_input", "delimiter"], + 
Self::Match => &["string_input", "pattern"], + Self::Extract(_) => &["string_input", "pattern"], + Self::ExtractAll(_) => &["string_input", "pattern"], + Self::Replace(_) => &["string_input", "pattern", "replacement"], + Self::Like => &["string_input", "pattern"], + Self::Ilike => &["string_input", "pattern"], + Self::Length => &["string_input"], + Self::Lower => &["string_input"], + Self::Upper => &["string_input"], + Self::Lstrip => &["string_input"], + Self::Rstrip => &["string_input"], + Self::Reverse => &["string_input"], + Self::Capitalize => &["string_input"], + Self::Left => &["string_input", "length"], + Self::Right => &["string_input", "length"], + Self::Find => &["string_input", "substring"], + Self::Rpad => &["string_input", "length", "pad"], + Self::Lpad => &["string_input", "length", "pad"], + Self::Repeat => &["string_input", "count"], + Self::Substr => &["string_input", "start", "length"], + Self::ToDate(_) => &["string_input", "format"], + Self::ToDatetime(_, _) => &["string_input", "format"], + Self::LengthBytes => &["string_input"], + Self::Normalize(_) => unimplemented!("Normalize not implemented"), + } + } } fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { - use functions::utf8::*; - use Utf8Expr::*; + use functions::utf8::{ + capitalize, contains, endswith, extract, extract_all, find, left, length, length_bytes, + lower, lpad, lstrip, match_, repeat, replace, reverse, right, rpad, rstrip, split, + startswith, to_date, to_datetime, upper, Utf8Expr, + }; + use Utf8Expr::{ + Capitalize, Contains, EndsWith, Extract, ExtractAll, Find, Ilike, Left, Length, + LengthBytes, Like, Lower, Lpad, Lstrip, Match, Normalize, Repeat, Replace, Reverse, Right, + Rpad, Rstrip, Split, StartsWith, Substr, ToDate, ToDatetime, Upper, + }; match expr { EndsWith => { ensure!(args.len() == 2, "endswith takes exactly two arguments"); @@ -78,19 +168,44 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { ensure!(args.len() == 2, "contains takes exactly two arguments"); Ok(contains(args[0].clone(), args[1].clone())) } - Split(_) => { + Split(true) => { + ensure!(args.len() == 2, "split takes exactly two arguments"); + Ok(split(args[0].clone(), args[1].clone(), true)) + } + Split(false) => { ensure!(args.len() == 2, "split takes exactly two arguments"); Ok(split(args[0].clone(), args[1].clone(), false)) } Match => { - unsupported_sql_err!("match") - } - Extract(_) => { - unsupported_sql_err!("extract") - } - ExtractAll(_) => { - unsupported_sql_err!("extract_all") + ensure!(args.len() == 2, "regexp_match takes exactly two arguments"); + Ok(match_(args[0].clone(), args[1].clone())) } + Extract(_) => match args { + [input, pattern] => Ok(extract(input.clone(), pattern.clone(), 0)), + [input, pattern, idx] => { + let idx = idx.as_literal().and_then(daft_dsl::LiteralValue::as_i64).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) + })?; + + Ok(extract(input.clone(), pattern.clone(), idx as usize)) + } + _ => { + invalid_operation_err!("regexp_extract takes exactly two or three arguments") + } + }, + ExtractAll(_) => match args { + [input, pattern] => Ok(extract_all(input.clone(), pattern.clone(), 0)), + [input, pattern, idx] => { + let idx = idx.as_literal().and_then(daft_dsl::LiteralValue::as_i64).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) + })?; + + 
Ok(extract_all(input.clone(), pattern.clone(), idx as usize)) + } + _ => { + invalid_operation_err!("regexp_extract_all takes exactly two or three arguments") + } + }, Replace(_) => { ensure!(args.len() == 3, "replace takes exactly three arguments"); Ok(replace( @@ -101,10 +216,10 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { )) } Like => { - unsupported_sql_err!("like") + unreachable!("like should be handled by the parser") } Ilike => { - unsupported_sql_err!("ilike") + unreachable!("ilike should be handled by the parser") } Length => { ensure!(args.len() == 1, "length takes exactly one argument"); @@ -163,8 +278,7 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { Ok(repeat(args[0].clone(), args[1].clone())) } Substr => { - ensure!(args.len() == 3, "substr takes exactly three arguments"); - Ok(substr(args[0].clone(), args[1].clone(), args[2].clone())) + unreachable!("substr should be handled by the parser") } ToDate(_) => { ensure!(args.len() == 2, "to_date takes exactly two arguments"); @@ -195,3 +309,233 @@ fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> { } } } + +pub struct SQLCountMatches; + +impl TryFrom<SQLFunctionArguments> for CountMatchesFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { + let whole_words = args.try_get_named("whole_words")?.unwrap_or(false); + let case_sensitive = args.try_get_named("case_sensitive")?.unwrap_or(true); + + Ok(Self { + whole_words, + case_sensitive, + }) + } +} + +impl SQLFunction for SQLCountMatches { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult<ExprRef> { + match inputs { + [input, pattern] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + Ok(utf8_count_matches(input, pattern, false, true)) + } + [input, pattern, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + let args: CountMatchesFunction = + planner.plan_function_args(args, &["whole_words", "case_sensitive"], 0)?; + + Ok(utf8_count_matches( + input, + pattern, + args.whole_words, + args.case_sensitive, + )) + } + _ => Err(PlannerError::invalid_operation( + format!("Invalid arguments for count_matches: '{inputs:?}'"), + )), + } + } +}
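Per the `TryFrom` impl above, `count_matches` defaults to `whole_words = false` and `case_sensitive = true` when only the input and pattern are given. A sketch (`:=` named-argument syntax assumed):

```sql
SELECT count_matches(utf8, 'foo') FROM tbl1;
SELECT count_matches(utf8, 'foo', whole_words := true, case_sensitive := false) FROM tbl1;
```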
+ +pub struct SQLNormalize; + +impl TryFrom<SQLFunctionArguments> for Utf8NormalizeOptions { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { + let remove_punct = args.try_get_named("remove_punct")?.unwrap_or(false); + let lowercase = args.try_get_named("lowercase")?.unwrap_or(false); + let nfd_unicode = args.try_get_named("nfd_unicode")?.unwrap_or(false); + let white_space = args.try_get_named("white_space")?.unwrap_or(false); + + Ok(Self { + remove_punct, + lowercase, + nfd_unicode, + white_space, + }) + } +} + +impl SQLFunction for SQLNormalize { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult<ExprRef> { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(normalize(input, Utf8NormalizeOptions::default())) + } + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: Utf8NormalizeOptions = planner.plan_function_args( + args, + &["remove_punct", "lowercase", "nfd_unicode", "white_space"], + 0, + )?; + Ok(normalize(input, args)) + } + _ => invalid_operation_err!("Invalid arguments for normalize"), + } + } +} + +pub struct SQLTokenizeEncode; +impl TryFrom<SQLFunctionArguments> for TokenizeEncodeFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { + if args.get_named("io_config").is_some() { + return Err(PlannerError::invalid_operation( + "io_config argument is not yet supported for tokenize_encode", + )); + } + + let tokens_path = args.try_get_named("tokens_path")?.ok_or_else(|| { + PlannerError::invalid_operation("tokens_path argument is required for tokenize_encode") + })?; + + let pattern = args.try_get_named("pattern")?; + let special_tokens = args.try_get_named("special_tokens")?; + let use_special_tokens = args.try_get_named("use_special_tokens")?.unwrap_or(false); + + Ok(Self { + tokens_path, + pattern, + special_tokens, + use_special_tokens, + io_config: None, + }) + } +} + +impl SQLFunction for SQLTokenizeEncode { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult<ExprRef> { + match inputs { + [input, tokens_path] => { + let input = planner.plan_function_arg(input)?; + let tokens_path = planner.plan_function_arg(tokens_path)?; + let tokens_path = tokens_path + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("tokens_path argument must be a string") + })?; + Ok(tokenize_encode(input, tokens_path, None, None, None, false)) + } + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: TokenizeEncodeFunction = planner.plan_function_args( + args, + &[ + "tokens_path", + "pattern", + "special_tokens", + "use_special_tokens", + ], + 1, // tokens_path can be named or positional + )?; + Ok(tokenize_encode( + input, + &args.tokens_path, + None, + args.pattern.as_deref(), + args.special_tokens.as_deref(), + args.use_special_tokens, + )) + } + _ => invalid_operation_err!("Invalid arguments for tokenize_encode"), + } + } +}
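`tokenize_encode` takes `tokens_path` either positionally or by name (the `1` passed to `plan_function_args` marks one acceptable positional argument), and `tokenize_decode` mirrors it below. A sketch with a hypothetical tokens file; the `:=` operator is an assumption:

```sql
SELECT tokenize_encode(utf8, 'tokens.txt') FROM tbl1;
SELECT tokenize_encode(utf8, tokens_path := 'tokens.txt', use_special_tokens := true) FROM tbl1;
```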
+pub struct SQLTokenizeDecode;
+impl TryFrom<SQLFunctionArguments> for TokenizeDecodeFunction {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        if args.get_named("io_config").is_some() {
+            return Err(PlannerError::invalid_operation(
+                "io_config argument is not yet supported for tokenize_decode",
+            ));
+        }
+
+        let tokens_path = args.try_get_named("tokens_path")?.ok_or_else(|| {
+            PlannerError::invalid_operation("tokens_path argument is required for tokenize_decode")
+        })?;
+
+        let pattern = args.try_get_named("pattern")?;
+        let special_tokens = args.try_get_named("special_tokens")?;
+
+        Ok(Self {
+            tokens_path,
+            pattern,
+            special_tokens,
+            io_config: None,
+        })
+    }
+}
+impl SQLFunction for SQLTokenizeDecode {
+    fn to_expr(
+        &self,
+        inputs: &[sqlparser::ast::FunctionArg],
+        planner: &crate::planner::SQLPlanner,
+    ) -> SQLPlannerResult<ExprRef> {
+        match inputs {
+            [input, tokens_path] => {
+                let input = planner.plan_function_arg(input)?;
+                let tokens_path = planner.plan_function_arg(tokens_path)?;
+                let tokens_path = tokens_path
+                    .as_literal()
+                    .and_then(|lit| lit.as_str())
+                    .ok_or_else(|| {
+                        PlannerError::invalid_operation("tokens_path argument must be a string")
+                    })?;
+                Ok(tokenize_decode(input, tokens_path, None, None, None))
+            }
+            [input, args @ ..] => {
+                let input = planner.plan_function_arg(input)?;
+                let args: TokenizeDecodeFunction = planner.plan_function_args(
+                    args,
+                    &["tokens_path", "pattern", "special_tokens"],
+                    1, // tokens_path can be named or positional
+                )?;
+                Ok(tokenize_decode(
+                    input,
+                    &args.tokens_path,
+                    None,
+                    args.pattern.as_deref(),
+                    args.special_tokens.as_deref(),
+                ))
+            }
+            _ => invalid_operation_err!("Invalid arguments for tokenize_decode"),
+        }
+    }
+}
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index b9f326ce99..5ffcbb16dd 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -29,7 +29,7 @@ use crate::{
 /// A named logical plan
 /// This is used to keep track of the table name associated with a logical plan while planning a SQL query
 #[derive(Debug, Clone)]
-pub(crate) struct Relation {
+pub struct Relation {
     pub(crate) inner: LogicalPlanBuilder,
     pub(crate) name: String,
 }
@@ -320,7 +320,13 @@
         let mut left_rel = self.plan_relation(&relation)?;
         for join in &from.joins {
-            use sqlparser::ast::{JoinConstraint, JoinOperator::*};
+            use sqlparser::ast::{
+                JoinConstraint,
+                JoinOperator::{
+                    AsOf, CrossApply, CrossJoin, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi,
+                    OuterApply, RightAnti, RightOuter, RightSemi,
+                },
+            };
             let Relation {
                 inner: right_plan,
                 name: right_name,
@@ -397,7 +403,19 @@
     fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
         match rel {
-            sqlparser::ast::TableFactor::Table { name, .. } => {
+            sqlparser::ast::TableFactor::Table {
+                name,
+                args: Some(args),
+                alias,
+                ..
+            } => {
+                let tbl_fn = name.0.first().unwrap().value.as_str();
+
+                self.plan_table_function(tbl_fn, args, alias)
+            }
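+            // Usage sketch (illustrative; the URI is a placeholder): call syntax in
+            // a table factor now dispatches to the table-function registry, e.g.
+            //   SELECT * FROM read_parquet('s3://bucket/data/**.parquet')
+            // while a plain `FROM tbl` still takes the catalog path below.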
+            sqlparser::ast::TableFactor::Table {
+                name, args: None, ..
+            } => {
                 let table_name = name.to_string();
                 let plan = self
                     .catalog
@@ -476,16 +494,18 @@
             }
         };
-        use sqlparser::ast::ExcludeSelectItem::*;
-        return match exclude {
+        use sqlparser::ast::ExcludeSelectItem::{Multiple, Single};
+        match exclude {
             Single(item) => current_relation
                 .inner
                 .schema()
                 .exclude(&[&item.to_string()]),
             Multiple(items) => {
-                let items =
-                    items.iter().map(|i| i.to_string()).collect::<Vec<String>>();
+                let items = items
+                    .iter()
+                    .map(std::string::ToString::to_string)
+                    .collect::<Vec<String>>();
                 current_relation.inner.schema().exclude(items.as_slice())
             }
@@ -497,7 +517,7 @@
                     .map(|n| col(n.as_ref()))
                     .collect::<Vec<_>>()
             })
-            .map_err(|e| e.into());
+            .map_err(std::convert::Into::into)
         } else {
             Ok(vec![col("*")])
         }
@@ -515,8 +535,7 @@
                 .or_else(|_| n.parse::<f64>().map(LiteralValue::Float64))
                 .map_err(|_| {
                     PlannerError::invalid_operation(format!(
-                        "could not parse number literal '{:?}'",
-                        n
+                        "could not parse number literal '{n:?}'"
                     ))
                 })?,
             Value::Boolean(b) => LiteralValue::Boolean(*b),
@@ -626,7 +645,26 @@
             SQLExpr::Ceil { expr, .. } => Ok(ceil(self.plan_expr(expr)?)),
             SQLExpr::Floor { expr, .. } => Ok(floor(self.plan_expr(expr)?)),
             SQLExpr::Position { .. } => unsupported_sql_err!("POSITION"),
-            SQLExpr::Substring { .. } => unsupported_sql_err!("SUBSTRING"),
+            SQLExpr::Substring {
+                expr,
+                substring_from,
+                substring_for,
+                special: true, // We only support SUBSTRING(expr, start, length) syntax
+            } => {
+                let (Some(substring_from), Some(substring_for)) = (substring_from, substring_for)
+                else {
+                    unsupported_sql_err!("SUBSTRING")
+                };
+
+                let expr = self.plan_expr(expr)?;
+                let start = self.plan_expr(substring_from)?;
+                let length = self.plan_expr(substring_for)?;
+
+                Ok(daft_dsl::functions::utf8::substr(expr, start, length))
+            }
+            SQLExpr::Substring { special: false, .. } => {
+                unsupported_sql_err!("`SUBSTRING(expr [FROM start] [FOR len])` syntax")
+            }
             SQLExpr::Trim { .. } => unsupported_sql_err!("TRIM"),
             SQLExpr::Overlay { .. } => unsupported_sql_err!("OVERLAY"),
             SQLExpr::Collate { .. } => unsupported_sql_err!("COLLATE"),
@@ -709,7 +747,22 @@
             }
             SQLExpr::Struct { .. } => unsupported_sql_err!("STRUCT"),
             SQLExpr::Named { .. } => unsupported_sql_err!("NAMED"),
-            SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"),
+            SQLExpr::Dictionary(dict) => {
+                let entries = dict
+                    .iter()
+                    .map(|entry| {
+                        let key = entry.key.value.clone();
+                        let value = self.plan_expr(&entry.value)?;
+                        let value = value.as_literal().ok_or_else(|| {
+                            PlannerError::invalid_operation("Dictionary value is not a literal")
+                        })?;
+                        let struct_field = Field::new(key, value.get_type());
+                        Ok((struct_field, value.clone()))
+                    })
+                    .collect::<SQLPlannerResult<Vec<_>>>()?;
+
+                Ok(Expr::Literal(LiteralValue::Struct(entries)).arced())
+            }
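+            // Usage sketch (illustrative; assumes sqlparser's DuckDB-style
+            // dictionary syntax): SELECT {'unit': 'km', 'scale': 2} plans to a
+            // LiteralValue::Struct literal; a non-literal value errors out above.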
             SQLExpr::Map(_) => unsupported_sql_err!("MAP"),
             SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()),
             SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"),
@@ -751,10 +804,10 @@
             // ---------------------------------
             // array/list
             // ---------------------------------
-            SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type))
-            | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, None)) => {
-                DataType::List(Box::new(self.sql_dtype_to_dtype(inner_type)?))
-            }
+            SQLDataType::Array(
+                ArrayElemTypeDef::AngleBracket(inner_type)
+                | ArrayElemTypeDef::SquareBracket(inner_type, None),
+            ) => DataType::List(Box::new(self.sql_dtype_to_dtype(inner_type)?)),
             SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, Some(size))) => {
                 DataType::FixedSizeList(
                     Box::new(self.sql_dtype_to_dtype(inner_type)?),
@@ -868,7 +921,7 @@
                 let dtype = self.sql_dtype_to_dtype(field_type)?;
                 let name = match field_name {
                     Some(name) => name.to_string(),
-                    None => format!("col_{}", idx),
+                    None => format!("col_{idx}"),
                 };
                 Ok(Field::new(name, dtype))
@@ -933,7 +986,7 @@
                     .ok_or_else(|| {
                         PlannerError::invalid_operation("subscript without a current relation")
                     })
-                    .map(|p| p.schema())?;
+                    .map(Relation::schema)?;
                 let expr_field = expr.to_field(schema.as_ref())?;
                 match expr_field.dtype {
                     DataType::List(_) | DataType::FixedSizeList(_, _) => {
@@ -946,7 +999,7 @@
                         invalid_operation_err!("Index must be a string literal")
                     }
                 }
-                DataType::Map(_) => Ok(daft_dsl::functions::map::get(expr, index)),
+                DataType::Map { .. } => Ok(daft_dsl::functions::map::get(expr, index)),
                 dtype => {
                     invalid_operation_err!("nested access on column with type: {}", dtype)
                 }
@@ -1080,7 +1133,7 @@ pub fn sql_expr<S: AsRef<str>>(s: S) -> SQLPlannerResult<ExprRef> {
 }
 fn ident_to_str(ident: &Ident) -> String {
-    if let Some('"') = ident.quote_style {
+    if ident.quote_style == Some('"') {
         ident.value.to_string()
     } else {
         ident.to_string()
diff --git a/src/daft-sql/src/python.rs b/src/daft-sql/src/python.rs
index 283184f014..b61d3fedd2 100644
--- a/src/daft-sql/src/python.rs
+++ b/src/daft-sql/src/python.rs
@@ -5,6 +5,31 @@ use pyo3::prelude::*;
 use crate::{catalog::SQLCatalog, functions::SQL_FUNCTIONS, planner::SQLPlanner};
+#[pyclass]
+pub struct SQLFunctionStub {
+    name: String,
+    docstring: String,
+    arg_names: Vec<&'static str>,
+}
+
+#[pymethods]
+impl SQLFunctionStub {
+    #[getter]
+    fn name(&self) -> PyResult<String> {
+        Ok(self.name.clone())
+    }
+
+    #[getter]
+    fn docstring(&self) -> PyResult<String> {
+        Ok(self.docstring.clone())
+    }
+
+    #[getter]
+    fn arg_names(&self) -> PyResult<Vec<&'static str>> {
+        Ok(self.arg_names.clone())
+    }
+}
+
 #[pyfunction]
 pub fn sql(
     sql: &str,
@@ -23,8 +48,20 @@ pub fn sql_expr(sql: &str) -> PyResult<PyExpr> {
 }
 #[pyfunction]
-pub fn list_sql_functions() -> Vec<String> {
-    SQL_FUNCTIONS.map.keys().cloned().collect()
+pub fn list_sql_functions() -> Vec<SQLFunctionStub> {
+    SQL_FUNCTIONS
+        .map
+        .keys()
+        .cloned()
+        .map(|name| {
+            let (docstring, args) = SQL_FUNCTIONS.docsmap.get(&name).unwrap();
+            SQLFunctionStub {
+                name,
+                docstring: docstring.to_string(),
+                arg_names: args.to_vec(),
+            }
+        })
+        .collect()
 }
 /// PyCatalog is the Python interface to the Catalog.
diff --git a/src/daft-sql/src/table_provider/mod.rs b/src/daft-sql/src/table_provider/mod.rs
new file mode 100644
index 0000000000..453fa4965f
--- /dev/null
+++ b/src/daft-sql/src/table_provider/mod.rs
@@ -0,0 +1,119 @@
+pub mod read_parquet;
+use std::{collections::HashMap, sync::Arc};
+
+use daft_plan::LogicalPlanBuilder;
+use once_cell::sync::Lazy;
+use read_parquet::ReadParquetFunction;
+use sqlparser::ast::{TableAlias, TableFunctionArgs};
+
+use crate::{
+    error::SQLPlannerResult,
+    modules::config::expr_to_iocfg,
+    planner::{Relation, SQLPlanner},
+    unsupported_sql_err,
+};
+
+pub(crate) static SQL_TABLE_FUNCTIONS: Lazy<SQLTableFunctions> = Lazy::new(|| {
+    let mut functions = SQLTableFunctions::new();
+    functions.add_fn("read_parquet", ReadParquetFunction);
+    #[cfg(feature = "python")]
+    functions.add_fn("read_deltalake", ReadDeltalakeFunction);
+
+    functions
+});
+
+/// TODOs
+///   - Use multimap for function variants.
+///   - Add more functions.
+pub struct SQLTableFunctions {
+    pub(crate) map: HashMap<String, Arc<dyn SQLTableFunction>>,
+}
+
+impl SQLTableFunctions {
+    /// Create a new [SQLTableFunctions] instance.
+    pub fn new() -> Self {
+        Self {
+            map: HashMap::new(),
+        }
+    }
+    /// Add a [SQLTableFunction] to the [SQLTableFunctions] instance.
+    pub(crate) fn add_fn<F: SQLTableFunction + 'static>(&mut self, name: &str, func: F) {
+        self.map.insert(name.to_string(), Arc::new(func));
+    }
+
+    /// Get a function by name from the [SQLTableFunctions] instance.
+    pub(crate) fn get(&self, name: &str) -> Option<&Arc<dyn SQLTableFunction>> {
+        self.map.get(name)
+    }
+}
+
+impl SQLPlanner {
+    pub(crate) fn plan_table_function(
+        &self,
+        fn_name: &str,
+        args: &TableFunctionArgs,
+        alias: &Option<TableAlias>,
+    ) -> SQLPlannerResult<Relation> {
+        let fns = &SQL_TABLE_FUNCTIONS;
+
+        let Some(func) = fns.get(fn_name) else {
+            unsupported_sql_err!("Function `{}` not found", fn_name);
+        };
+
+        let builder = func.plan(self, args)?;
+        let name = alias
+            .as_ref()
+            .map(|a| a.name.value.clone())
+            .unwrap_or_else(|| fn_name.to_string());
+
+        Ok(Relation::new(builder, name))
+    }
+}
+
+pub(crate) trait SQLTableFunction: Send + Sync {
+    fn plan(
+        &self,
+        planner: &SQLPlanner,
+        args: &TableFunctionArgs,
+    ) -> SQLPlannerResult<LogicalPlanBuilder>;
+}
+
+pub struct ReadDeltalakeFunction;
+
+#[cfg(feature = "python")]
+impl SQLTableFunction for ReadDeltalakeFunction {
+    fn plan(
+        &self,
+        planner: &SQLPlanner,
+        args: &TableFunctionArgs,
+    ) -> SQLPlannerResult<LogicalPlanBuilder> {
+        let (uri, io_config) = match args.args.as_slice() {
+            [uri] => (uri, None),
+            [uri, io_config] => {
+                let args = planner.parse_function_args(&[io_config.clone()], &["io_config"], 0)?;
+                let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?;
+
+                (uri, io_config)
+            }
+            _ => unsupported_sql_err!("Expected one or two arguments"),
+        };
+        let uri = planner.plan_function_arg(uri)?;
+
+        let Some(uri) = uri.as_literal().and_then(|lit| lit.as_str()) else {
+            unsupported_sql_err!("Expected a string literal for the first argument");
+        };
+
+        LogicalPlanBuilder::delta_scan(uri, io_config, true).map_err(From::from)
+    }
+}
+
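+// Usage sketch (illustrative; the URI is a placeholder):
+//   SELECT * FROM read_deltalake('s3://bucket/tbl')
+// resolves through SQL_TABLE_FUNCTIONS and plans via LogicalPlanBuilder::delta_scan;
+// an optional named `io_config` argument is parsed by expr_to_iocfg above.
+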
+#[cfg(not(feature = "python"))]
+impl SQLTableFunction for ReadDeltalakeFunction {
+    fn plan(
+        &self,
+        _planner: &SQLPlanner,
+        _args: &TableFunctionArgs,
+    ) -> SQLPlannerResult<LogicalPlanBuilder> {
+        unsupported_sql_err!("`read_deltalake` function is not supported. Enable the `python` feature to use this function.")
+    }
+}
diff --git a/src/daft-sql/src/table_provider/read_parquet.rs b/src/daft-sql/src/table_provider/read_parquet.rs
new file mode 100644
index 0000000000..36f84507a0
--- /dev/null
+++ b/src/daft-sql/src/table_provider/read_parquet.rs
@@ -0,0 +1,77 @@
+use daft_core::prelude::TimeUnit;
+use daft_plan::{LogicalPlanBuilder, ParquetScanBuilder};
+use sqlparser::ast::TableFunctionArgs;
+
+use super::SQLTableFunction;
+use crate::{
+    error::{PlannerError, SQLPlannerResult},
+    functions::SQLFunctionArguments,
+    modules::config::expr_to_iocfg,
+    planner::SQLPlanner,
+};
+
+pub(super) struct ReadParquetFunction;
+
+impl TryFrom<SQLFunctionArguments> for ParquetScanBuilder {
+    type Error = PlannerError;
+
+    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
+        let glob_paths: String = args.try_get_positional(0)?.ok_or_else(|| {
+            PlannerError::invalid_operation("path is required for `read_parquet`")
+        })?;
+        let infer_schema = args.try_get_named("infer_schema")?.unwrap_or(true);
+        let coerce_int96_timestamp_unit =
+            args.try_get_named::<String>("coerce_int96_timestamp_unit")?;
+        let coerce_int96_timestamp_unit: TimeUnit = coerce_int96_timestamp_unit
+            .as_deref()
+            .unwrap_or("nanoseconds")
+            .parse::<TimeUnit>()
+            .map_err(|_| {
+                PlannerError::invalid_argument("coerce_int96_timestamp_unit", "read_parquet")
+            })?;
+        let chunk_size = args.try_get_named("chunk_size")?;
+        let multithreaded = args.try_get_named("multithreaded")?.unwrap_or(true);
+
+        let field_id_mapping = None; // TODO
+        let row_groups = None; // TODO
+        let schema = None; // TODO
+        let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?;
+
+        Ok(Self {
+            glob_paths: vec![glob_paths],
+            infer_schema,
+            coerce_int96_timestamp_unit,
+            field_id_mapping,
+            row_groups,
+            chunk_size,
+            io_config,
+            multithreaded,
+            schema,
+        })
+    }
+}
+
+impl SQLTableFunction for ReadParquetFunction {
+    fn plan(
+        &self,
+        planner: &SQLPlanner,
+        args: &TableFunctionArgs,
+    ) -> SQLPlannerResult<LogicalPlanBuilder> {
+        let builder: ParquetScanBuilder = planner.plan_function_args(
+            args.args.as_slice(),
+            &[
+                "infer_schema",
+                "coerce_int96_timestamp_unit",
+                "chunk_size",
+                "multithreaded",
+                // "schema",
+                // "field_id_mapping",
+                // "row_groups",
+                "io_config",
+            ],
+            1, // 1 positional argument (path)
+        )?;
+
+        builder.finish().map_err(From::from)
+    }
+}
diff --git a/src/daft-stats/src/column_stats/logical.rs b/src/daft-stats/src/column_stats/logical.rs
index 29b2d47421..6b3c63b471 100644
--- a/src/daft-stats/src/column_stats/logical.rs
+++ b/src/daft-stats/src/column_stats/logical.rs
@@ -31,7 +31,7 @@ impl std::ops::BitAnd for &ColumnRangeStatistics {
         let lt = self.to_truth_value();
         let rt = rhs.to_truth_value();
-        use TruthValue::*;
+        use TruthValue::{False, Maybe, True};
         let nv = match (lt, rt) {
             (False, _) => False,
             (_, False) => False,
@@ -55,7 +55,7 @@ impl std::ops::BitOr for &ColumnRangeStatistics {
         // +-------+-------+-------+------+
         let lt = self.to_truth_value();
         let rt = rhs.to_truth_value();
-        use TruthValue::*;
+        use TruthValue::{False, Maybe, True};
         let nv = match (lt, rt) {
             (False, False) => False,
             (True, _) => True,
diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs
index df96daa373..b5f71f7771 100644
--- a/src/daft-stats/src/column_stats/mod.rs
+++ b/src/daft-stats/src/column_stats/mod.rs
@@ -14,7 +14,7 @@ pub enum ColumnRangeStatistics {
     Loaded(Series, Series),
 }
-#[derive(PartialEq, Debug)]
+#[derive(PartialEq, Eq, Debug)]
 pub enum TruthValue {
     False,
     Maybe,
@@ -52,6 +52,7 @@ impl
ColumnRangeStatistics { } } + #[must_use] pub fn supports_dtype(dtype: &DataType) -> bool { match dtype { // SUPPORTED TYPES: @@ -71,12 +72,13 @@ impl ColumnRangeStatistics { // UNSUPPORTED TYPES: // Types that don't support comparisons and can't be used as ColumnRangeStatistics - DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, + DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::SparseTensor(..) | DataType::FixedShapeSparseTensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Map { .. } | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false, #[cfg(feature = "python")] DataType::Python => false, } } + #[must_use] pub fn to_truth_value(&self) -> TruthValue { match self { Self::Missing => TruthValue::Maybe, @@ -93,6 +95,7 @@ impl ColumnRangeStatistics { } } + #[must_use] pub fn from_truth_value(tv: TruthValue) -> Self { let (lower, upper) = match tv { TruthValue::False => (false, false), @@ -123,6 +126,7 @@ impl ColumnRangeStatistics { } } + #[must_use] pub fn from_series(series: &Series) -> Self { let lower = series.min(None).unwrap(); let upper = series.max(None).unwrap(); @@ -160,36 +164,35 @@ impl ColumnRangeStatistics { Self::Loaded(l, r) => { match (l.data_type(), dtype) { // Int casting to higher bitwidths - (DataType::Int8, DataType::Int16) | - (DataType::Int8, DataType::Int32) | - (DataType::Int8, DataType::Int64) | - (DataType::Int16, DataType::Int32) | - (DataType::Int16, DataType::Int64) | - (DataType::Int32, DataType::Int64) | - // UInt casting to higher bitwidths - (DataType::UInt8, DataType::UInt16) | - (DataType::UInt8, DataType::UInt32) | - (DataType::UInt8, DataType::UInt64) | - (DataType::UInt16, DataType::UInt32) | - (DataType::UInt16, DataType::UInt64) | - (DataType::UInt32, DataType::UInt64) | - // Float casting to higher bitwidths - (DataType::Float32, DataType::Float64) | - // Numeric to temporal casting from smaller-than-eq bitwidths - (DataType::Int8, DataType::Date) | - (DataType::Int16, DataType::Date) | - (DataType::Int32, DataType::Date) | - (DataType::Int8, DataType::Timestamp(..)) | - (DataType::Int16, DataType::Timestamp(..)) | - (DataType::Int32, DataType::Timestamp(..)) | - (DataType::Int64, DataType::Timestamp(..)) | - // Binary to Utf8 - (DataType::Binary, DataType::Utf8) - => Ok(Self::Loaded( + ( + DataType::Int8, + DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Date + | DataType::Timestamp(..), + ) + | ( + DataType::Int16, + DataType::Int32 + | DataType::Int64 + | DataType::Date + | DataType::Timestamp(..), + ) + | ( + DataType::Int32, + DataType::Int64 | DataType::Date | DataType::Timestamp(..), + ) + | (DataType::UInt8, DataType::UInt16 | DataType::UInt32 | DataType::UInt64) + | (DataType::UInt16, DataType::UInt32 | DataType::UInt64) + | (DataType::UInt32, DataType::UInt64) + | (DataType::Float32, DataType::Float64) + | (DataType::Int64, DataType::Timestamp(..)) + | (DataType::Binary, DataType::Utf8) => Ok(Self::Loaded( l.cast(dtype).context(DaftCoreComputeSnafu)?, r.cast(dtype).context(DaftCoreComputeSnafu)?, )), - _ => Ok(Self::Missing) + _ => Ok(Self::Missing), } } } @@ -203,10 +206,9 @@ impl 
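// Illustrative reading of the widening table above: Int32 bounds such as
// [3, 41] cast losslessly to Int64 (or to Date/Timestamp) and stay Loaded,
// while any pairing outside the table -- e.g. Int64 -> Int32 -- degrades to
// ColumnRangeStatistics::Missing.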
std::fmt::Display for ColumnRangeStatistics { Self::Loaded(lower, upper) => write!( f, "ColumnRangeStatistics: -lower:\n{} -upper:\n{} - ", - lower, upper +lower:\n{lower} +upper:\n{upper} + " ), } } @@ -223,7 +225,7 @@ impl TryFrom<&daft_dsl::LiteralValue> for ColumnRangeStatistics { fn try_from(value: &daft_dsl::LiteralValue) -> crate::Result { let series = value.to_series(); assert_eq!(series.len(), 1); - Self::new(Some(series.clone()), Some(series.clone())) + Self::new(Some(series.clone()), Some(series)) } } diff --git a/src/daft-stats/src/partition_spec.rs b/src/daft-stats/src/partition_spec.rs index ccf6d1c713..24834bf116 100644 --- a/src/daft-stats/src/partition_spec.rs +++ b/src/daft-stats/src/partition_spec.rs @@ -10,12 +10,14 @@ pub struct PartitionSpec { } impl PartitionSpec { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Keys = {}", self.keys)); res } + #[must_use] pub fn to_fill_map(&self) -> HashMap<&str, ExprRef> { self.keys .schema diff --git a/src/daft-stats/src/table_metadata.rs b/src/daft-stats/src/table_metadata.rs index d7fc6d4cbf..bcd76e96c4 100644 --- a/src/daft-stats/src/table_metadata.rs +++ b/src/daft-stats/src/table_metadata.rs @@ -1,11 +1,12 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TableMetadata { pub length: usize, } impl TableMetadata { + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Length = {}", self.length)); diff --git a/src/daft-stats/src/table_stats.rs b/src/daft-stats/src/table_stats.rs index 0fff747c98..e0d91d24c6 100644 --- a/src/daft-stats/src/table_stats.rs +++ b/src/daft-stats/src/table_stats.rs @@ -34,6 +34,7 @@ impl TableStatistics { Ok(Self { columns }) } + #[must_use] pub fn from_table(table: &Table) -> Self { let mut columns = IndexMap::with_capacity(table.num_columns()); for name in table.column_names() { @@ -106,7 +107,11 @@ impl TableStatistics { sum_so_far += elem_size; } } else { - for elem_size in self.columns.values().map(|c| c.element_size()) { + for elem_size in self + .columns + .values() + .map(super::column_stats::ColumnRangeStatistics::element_size) + { sum_so_far += elem_size?.unwrap_or(0.); } } @@ -119,20 +124,20 @@ impl TableStatistics { Expr::Alias(col, _) => self.eval_expression(col.as_ref()), Expr::Column(col_name) => { let col = self.columns.get(col_name.as_ref()); - if let Some(col) = col { - Ok(col.clone()) - } else { - Err(crate::Error::DaftCoreCompute { + let Some(col) = col else { + return Err(crate::Error::DaftCoreCompute { source: DaftError::FieldNotFound(col_name.to_string()), - }) - } + }); + }; + + Ok(col.clone()) } Expr::Literal(lit_value) => lit_value.try_into(), Expr::Not(col) => self.eval_expression(col)?.not(), Expr::BinaryOp { op, left, right } => { let lhs = self.eval_expression(left)?; let rhs = self.eval_expression(right)?; - use daft_dsl::Operator::*; + use daft_dsl::Operator::{And, Eq, Gt, GtEq, Lt, LtEq, Minus, NotEq, Or, Plus}; match op { Lt => lhs.lt(&rhs), LtEq => lhs.lte(&rhs), @@ -161,7 +166,7 @@ impl TableStatistics { fill_map: Option<&HashMap<&str, ExprRef>>, ) -> crate::Result { let mut columns = IndexMap::new(); - for (field_name, field) in schema.fields.iter() { + for (field_name, field) in &schema.fields { let crs = match self.columns.get(field_name) { Some(column_stat) => column_stat .cast(&field.dtype) @@ -194,7 +199,6 @@ impl Display for TableStatistics { 
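// Illustrative: eval_expression folds a predicate over column ranges, e.g. for
// a column whose range is [0, 20], col("x").lt(lit(10)) can go either way
// (TruthValue::Maybe), whereas a range of [11, 20] would fold to False and let
// a reader skip this table entirely.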
#[cfg(test)] mod test { - use daft_core::prelude::*; use daft_dsl::{col, lit}; use daft_table::Table; diff --git a/src/daft-table/src/ffi.rs b/src/daft-table/src/ffi.rs index 37a118c50e..81d495728c 100644 --- a/src/daft-table/src/ffi.rs +++ b/src/daft-table/src/ffi.rs @@ -42,12 +42,12 @@ pub fn record_batches_to_table( let columns = cols .into_iter() .enumerate() - .map(|(i, c)| { - let c = cast_array_for_daft_if_needed(c); - Series::try_from((names.get(i).unwrap().as_str(), c)) + .map(|(i, array)| { + let cast_array = cast_array_for_daft_if_needed(array); + Series::try_from((names.get(i).unwrap().as_str(), cast_array)) }) .collect::>>()?; - tables.push(Table::new_with_size(schema.clone(), columns, num_rows)?) + tables.push(Table::new_with_size(schema.clone(), columns, num_rows)?); } Ok(Table::concat(tables.as_slice())?) }) @@ -72,7 +72,7 @@ pub fn table_to_record_batch( let record = pyarrow .getattr(pyo3::intern!(py, "RecordBatch"))? - .call_method1(pyo3::intern!(py, "from_arrays"), (arrays, names.to_vec()))?; + .call_method1(pyo3::intern!(py, "from_arrays"), (arrays, names.clone()))?; Ok(record.into()) } diff --git a/src/daft-table/src/growable/mod.rs b/src/daft-table/src/growable/mod.rs index de736ff29e..6404a78c04 100644 --- a/src/daft-table/src/growable/mod.rs +++ b/src/daft-table/src/growable/mod.rs @@ -51,7 +51,7 @@ impl<'a> GrowableTable<'a> { if !self.growables.is_empty() { self.growables .iter_mut() - .for_each(|g| g.extend(index, start, len)) + .for_each(|g| g.extend(index, start, len)); } } @@ -60,7 +60,7 @@ impl<'a> GrowableTable<'a> { if !self.growables.is_empty() { self.growables .iter_mut() - .for_each(|g| g.add_nulls(additional)) + .for_each(|g| g.add_nulls(additional)); } } diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 3669fda3f5..cf96344a53 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -1,5 +1,6 @@ #![feature(hash_raw_entry)] #![feature(let_chains)] +#![feature(iterator_try_collect)] use core::slice; use std::{ @@ -145,9 +146,9 @@ impl Table { pub fn empty(schema: Option) -> DaftResult { let schema = schema.unwrap_or_else(|| Schema::empty().into()); let mut columns: Vec = Vec::with_capacity(schema.names().len()); - for (field_name, field) in schema.fields.iter() { + for (field_name, field) in &schema.fields { let series = Series::empty(field_name, &field.dtype); - columns.push(series) + columns.push(series); } Ok(Self::new_unchecked(schema, columns, 0)) } @@ -160,9 +161,7 @@ impl Table { /// /// * `columns` - Columns to crate a table from as [`Series`] objects pub fn from_nonempty_columns(columns: Vec) -> DaftResult { - if columns.is_empty() { - panic!("Cannot call Table::new() with empty columns. This indicates an internal error, please file an issue."); - } + assert!(!columns.is_empty(), "Cannot call Table::new() with empty columns. 
This indicates an internal error, please file an issue."); let schema = Schema::new(columns.iter().map(|s| s.field().clone()).collect())?; let schema: SchemaRef = schema.into(); @@ -342,7 +341,7 @@ impl Table { let num_filtered = mask .validity() .map(|validity| arrow2::bitmap::and(validity, mask.as_bitmap()).unset_bits()) - .unwrap_or(mask.as_bitmap().unset_bits()); + .unwrap_or_else(|| mask.as_bitmap().unset_bits()); mask.len() - num_filtered }; @@ -480,6 +479,7 @@ impl Table { } } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), + AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => { @@ -495,6 +495,7 @@ impl Table { fn eval_expression(&self, expr: &Expr) -> DaftResult { use crate::Expr::*; + let expected_field = expr.to_field(self.schema.as_ref())?; let series = match expr { Alias(child, name) => Ok(self.eval_expression(child)?.rename(name)), @@ -572,6 +573,7 @@ impl Table { } }, }?; + if expected_field.name != series.field().name { return Err(DaftError::ComputeError(format!( "Mismatch of expected expression name and name from computed series ({} vs {}) for expression: {expr}", @@ -579,32 +581,41 @@ impl Table { series.field().name ))); } - if expected_field.dtype != series.field().dtype { - panic!("Mismatch of expected expression data type and data type from computed series, {} vs {}", expected_field.dtype, series.field().dtype); - } + + assert!( + !(expected_field.dtype != series.field().dtype), + "Data type mismatch in expression evaluation:\n\ + Expected type: {}\n\ + Computed type: {}\n\ + Expression: {}\n\ + This likely indicates an internal error in type inference or computation.", + expected_field.dtype, + series.field().dtype, + expr + ); Ok(series) } pub fn eval_expression_list(&self, exprs: &[ExprRef]) -> DaftResult { - let result_series = exprs + let result_series: Vec<_> = exprs .iter() .map(|e| self.eval_expression(e)) - .collect::>>()?; + .try_collect()?; - let fields = result_series - .iter() - .map(|s| s.field().clone()) - .collect::>(); - let mut seen: HashSet = HashSet::new(); - for field in fields.iter() { + let fields: Vec<_> = result_series.iter().map(|s| s.field().clone()).collect(); + + let mut seen = HashSet::new(); + + for field in &fields { let name = &field.name; if seen.contains(name) { return Err(DaftError::ValueError(format!( "Duplicate name found when evaluating expressions: {name}" ))); } - seen.insert(name.clone()); + seen.insert(name); } + let new_schema = Schema::new(fields)?; let has_agg_expr = exprs.iter().any(|e| matches!(e.as_ref(), Expr::Agg(..))); @@ -696,16 +707,11 @@ impl Table { // Begin the body. res.push_str("\n"); - let head_rows; - let tail_rows; - - if self.len() > 10 { - head_rows = 5; - tail_rows = 5; + let (head_rows, tail_rows) = if self.len() > 10 { + (5, 5) } else { - head_rows = self.len(); - tail_rows = 0; - } + (self.len(), 0) + }; let styled_td = ""); - for col in self.columns.iter() { + for col in &self.columns { res.push_str(styled_td); res.push_str(&html_value(col, i)); res.push_str(""); @@ -726,7 +732,7 @@ impl Table { if tail_rows != 0 { res.push_str(""); - for _ in self.columns.iter() { + for _ in &self.columns { res.push_str(""); } res.push_str("\n"); @@ -736,7 +742,7 @@ impl Table { // Begin row. 
res.push_str(""); - for col in self.columns.iter() { + for col in &self.columns { res.push_str(styled_td); res.push_str(&html_value(col, i)); res.push_str(""); diff --git a/src/daft-table/src/ops/agg.rs b/src/daft-table/src/ops/agg.rs index 70abdf69f4..93ef8425d7 100644 --- a/src/daft-table/src/ops/agg.rs +++ b/src/daft-table/src/ops/agg.rs @@ -100,7 +100,7 @@ impl Table { // Take fast path short circuit if there is only 1 group let (groupkeys_table, grouped_col) = if groupvals_indices.is_empty() { - let empty_groupkeys_table = Self::empty(Some(groupby_table.schema.clone()))?; + let empty_groupkeys_table = Self::empty(Some(groupby_table.schema))?; let empty_udf_output_col = Series::empty( evaluated_inputs .first() diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs index 2c0fc0fee3..bdd715ac4a 100644 --- a/src/daft-table/src/ops/explode.rs +++ b/src/daft-table/src/ops/explode.rs @@ -73,7 +73,7 @@ impl Table { } let mut exploded_columns = evaluated_columns .iter() - .map(|c| c.explode()) + .map(daft_core::series::Series::explode) .collect::>>()?; let capacity_expected = exploded_columns.first().unwrap().len(); diff --git a/src/daft-table/src/ops/groups.rs b/src/daft-table/src/ops/groups.rs index 76e6c04c33..1edccccdc7 100644 --- a/src/daft-table/src/ops/groups.rs +++ b/src/daft-table/src/ops/groups.rs @@ -32,7 +32,7 @@ impl Table { let mut key_indices: Vec = Vec::with_capacity(probe_table.len()); let mut values_indices: Vec> = Vec::with_capacity(probe_table.len()); - for (idx_hash, val_idx) in probe_table.into_iter() { + for (idx_hash, val_idx) in probe_table { key_indices.push(idx_hash.idx); values_indices.push(val_idx); } diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs index 0abdcb8867..c011597c3f 100644 --- a/src/daft-table/src/ops/hash.rs +++ b/src/daft-table/src/ops/hash.rs @@ -19,7 +19,7 @@ pub struct IndexHash { impl Hash for IndexHash { fn hash(&self, state: &mut H) { - state.write_u64(self.hash) + state.write_u64(self.hash); } } diff --git a/src/daft-table/src/ops/joins/merge_join.rs b/src/daft-table/src/ops/joins/merge_join.rs index eb57db2d1e..4b5a861811 100644 --- a/src/daft-table/src/ops/joins/merge_join.rs +++ b/src/daft-table/src/ops/joins/merge_join.rs @@ -88,7 +88,7 @@ pub fn merge_inner_join(left: &Table, right: &Table) -> DaftResult<(Series, Seri )?); } let combined_comparator = |a_idx: usize, b_idx: usize| -> Option { - for comparator in cmp_list.iter() { + for comparator in &cmp_list { match comparator(a_idx, b_idx) { Some(Ordering::Equal) => continue, other => return other, @@ -218,11 +218,11 @@ pub fn merge_inner_join(left: &Table, right: &Table) -> DaftResult<(Series, Seri match state { // If extending a left-side run or propagating an existing right-side run, move left pointer forward. MergeJoinState::LeftEqualRun(_) | MergeJoinState::StagedRightEqualRun(_) => { - left_idx += 1 + left_idx += 1; } // If extending a right-side run or propagating an existing left-side run, move right pointer forward. 
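// Roughly: joining left = [1, 2, 2, 3] with right = [2, 2], the left-side run
// of equal keys is staged once and replayed for the second right-side 2, so
// all four (2, 2) pairs are emitted without re-scanning either side.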
MergeJoinState::RightEqualRun(_) | MergeJoinState::StagedLeftEqualRun(_) => { - right_idx += 1 + right_idx += 1; } _ => unreachable!(), } diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 5bb4f77d4b..0c6b678d35 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -52,9 +52,8 @@ fn add_non_join_key_columns( for field in left.schema.fields.values() { if join_keys.contains(&field.name) { continue; - } else { - join_series.push(left.get_column(&field.name)?.take(&lidx)?); } + join_series.push(left.get_column(&field.name)?.take(&lidx)?); } drop(lidx); @@ -62,9 +61,9 @@ fn add_non_join_key_columns( for field in right.schema.fields.values() { if join_keys.contains(&field.name) { continue; - } else { - join_series.push(right.get_column(&field.name)?.take(&ridx)?); } + + join_series.push(right.get_column(&field.name)?.take(&ridx)?); } Ok(join_series) diff --git a/src/daft-table/src/ops/mod.rs b/src/daft-table/src/ops/mod.rs index 6d53b60e13..66e2a958e2 100644 --- a/src/daft-table/src/ops/mod.rs +++ b/src/daft-table/src/ops/mod.rs @@ -1,7 +1,7 @@ mod agg; mod explode; mod groups; -pub(crate) mod hash; +pub mod hash; mod joins; mod partition; mod pivot; diff --git a/src/daft-table/src/ops/partition.rs b/src/daft-table/src/ops/partition.rs index 6d07d3f778..93d61f1547 100644 --- a/src/daft-table/src/ops/partition.rs +++ b/src/daft-table/src/ops/partition.rs @@ -36,7 +36,7 @@ impl Table { for (s_idx, t_idx) in targets.as_arrow().values_iter().enumerate() { if *t_idx >= (num_partitions as u64) { - return Err(DaftError::ComputeError(format!("idx in target array is out of bounds, target idx {} at index {} out of {} partitions", t_idx, s_idx, num_partitions))); + return Err(DaftError::ComputeError(format!("idx in target array is out of bounds, target idx {t_idx} at index {s_idx} out of {num_partitions} partitions"))); } output_to_input_idx[unsafe { t_idx.as_usize() }].push(s_idx as u64); diff --git a/src/daft-table/src/ops/pivot.rs b/src/daft-table/src/ops/pivot.rs index 4418d4365e..2eaf7274e6 100644 --- a/src/daft-table/src/ops/pivot.rs +++ b/src/daft-table/src/ops/pivot.rs @@ -23,7 +23,7 @@ fn map_name_to_pivot_key_idx<'a>( .collect::>(); let mut name_to_pivot_key_idx_mapping = std::collections::HashMap::new(); - for name in names.iter() { + for name in names { if let Some(pivot_key_idx) = pivot_key_str_to_idx_mapping.get(name.as_str()) { name_to_pivot_key_idx_mapping.insert(name, *pivot_key_idx); } @@ -46,7 +46,7 @@ fn map_pivot_key_idx_to_values_indices( for (p_key, p_indices) in pivot_keys_indices.iter().zip(pivot_vals_indices.iter()) { let p_indices_hashset = p_indices.iter().collect::>(); let mut values_indices = Vec::new(); - for g_indices_hashset in group_vals_indices_hashsets.iter() { + for g_indices_hashset in &group_vals_indices_hashsets { let matches = g_indices_hashset .intersection(&p_indices_hashset) .collect::>(); diff --git a/src/daft-table/src/probeable/probe_set.rs b/src/daft-table/src/probeable/probe_set.rs index 0fdff1e0fc..a948ad2a4b 100644 --- a/src/daft-table/src/probeable/probe_set.rs +++ b/src/daft-table/src/probeable/probe_set.rs @@ -15,7 +15,7 @@ use daft_core::{ use super::{ArrowTableEntry, IndicesMapper, Probeable, ProbeableBuilder}; use crate::{ops::hash::IndexHash, Table}; -pub(crate) struct ProbeSet { +pub struct ProbeSet { schema: SchemaRef, hash_table: HashMap, tables: Vec, @@ -156,7 +156,7 @@ impl Probeable for ProbeSet { } } -pub(crate) struct ProbeSetBuilder(pub ProbeSet); +pub struct 
ProbeSetBuilder(pub ProbeSet); impl ProbeableBuilder for ProbeSetBuilder { fn add_table(&mut self, table: &Table) -> DaftResult<()> { diff --git a/src/daft-table/src/probeable/probe_table.rs b/src/daft-table/src/probeable/probe_table.rs index c8e401084f..c0a4bde0de 100644 --- a/src/daft-table/src/probeable/probe_table.rs +++ b/src/daft-table/src/probeable/probe_table.rs @@ -16,7 +16,7 @@ use daft_core::{ use super::{ArrowTableEntry, IndicesMapper, Probeable, ProbeableBuilder}; use crate::{ops::hash::IndexHash, Table}; -pub(crate) struct ProbeTable { +pub struct ProbeTable { schema: SchemaRef, hash_table: HashMap, IdentityBuildHasher>, tables: Vec, @@ -52,7 +52,7 @@ impl ProbeTable { fn probe<'a>( &'a self, input: &'a Table, - ) -> DaftResult> + 'a> { + ) -> DaftResult> + 'a> { assert_eq!(self.schema.len(), input.schema.len()); assert!(self .schema @@ -173,7 +173,7 @@ impl Probeable for ProbeTable { } } -pub(crate) struct ProbeTableBuilder(pub ProbeTable); +pub struct ProbeTableBuilder(pub ProbeTable); impl ProbeableBuilder for ProbeTableBuilder { fn add_table(&mut self, table: &Table) -> DaftResult<()> { diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 3bacbcf019..89f1a12016 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -43,7 +43,8 @@ impl PyTable { } pub fn filter(&self, py: Python, exprs: Vec) -> PyResult { - let converted_exprs: Vec = exprs.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = + exprs.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| Ok(self.table.filter(converted_exprs.as_slice())?.into())) } @@ -53,8 +54,10 @@ impl PyTable { sort_keys: Vec, descending: Vec, ) -> PyResult { - let converted_exprs: Vec = - sort_keys.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = sort_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { Ok(self .table @@ -69,8 +72,10 @@ impl PyTable { sort_keys: Vec, descending: Vec, ) -> PyResult { - let converted_exprs: Vec = - sort_keys.into_iter().map(|e| e.into()).collect(); + let converted_exprs: Vec = sort_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { Ok(self .table @@ -81,9 +86,9 @@ impl PyTable { pub fn agg(&self, py: Python, to_agg: Vec, group_by: Vec) -> PyResult { let converted_to_agg: Vec = - to_agg.into_iter().map(|e| e.into()).collect(); + to_agg.into_iter().map(std::convert::Into::into).collect(); let converted_group_by: Vec = - group_by.into_iter().map(|e| e.into()).collect(); + group_by.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .table @@ -101,7 +106,7 @@ impl PyTable { names: Vec, ) -> PyResult { let converted_group_by: Vec = - group_by.into_iter().map(|e| e.into()).collect(); + group_by.into_iter().map(std::convert::Into::into).collect(); let converted_pivot_col: daft_dsl::ExprRef = pivot_col.into(); let converted_values_col: daft_dsl::ExprRef = values_col.into(); py.allow_threads(|| { @@ -125,8 +130,10 @@ impl PyTable { right_on: Vec, how: JoinType, ) -> PyResult { - let left_exprs: Vec = left_on.into_iter().map(|e| e.into()).collect(); - let right_exprs: Vec = right_on.into_iter().map(|e| e.into()).collect(); + let left_exprs: Vec = + left_on.into_iter().map(std::convert::Into::into).collect(); + let right_exprs: Vec = + right_on.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .table @@ -148,8 +155,10 @@ impl PyTable { right_on: Vec, is_sorted: bool, ) -> 
PyResult { - let left_exprs: Vec = left_on.into_iter().map(|e| e.into()).collect(); - let right_exprs: Vec = right_on.into_iter().map(|e| e.into()).collect(); + let left_exprs: Vec = + left_on.into_iter().map(std::convert::Into::into).collect(); + let right_exprs: Vec = + right_on.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .table @@ -254,13 +263,14 @@ impl PyTable { "Can not partition into negative number of partitions: {num_partitions}" ))); } - let exprs: Vec = exprs.into_iter().map(|e| e.into()).collect(); + let exprs: Vec = + exprs.into_iter().map(std::convert::Into::into).collect(); py.allow_threads(|| { Ok(self .table .partition_by_hash(exprs.as_slice(), num_partitions as usize)? .into_iter() - .map(|t| t.into()) + .map(std::convert::Into::into) .collect::>()) }) } @@ -287,7 +297,7 @@ impl PyTable { .table .partition_by_random(num_partitions as usize, seed as u64)? .into_iter() - .map(|t| t.into()) + .map(std::convert::Into::into) .collect::>()) }) } @@ -299,13 +309,16 @@ impl PyTable { boundaries: &Self, descending: Vec, ) -> PyResult> { - let exprs: Vec = partition_keys.into_iter().map(|e| e.into()).collect(); + let exprs: Vec = partition_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { Ok(self .table .partition_by_range(exprs.as_slice(), &boundaries.table, descending.as_slice())? .into_iter() - .map(|t| t.into()) + .map(std::convert::Into::into) .collect::>()) }) } @@ -315,10 +328,16 @@ impl PyTable { py: Python, partition_keys: Vec, ) -> PyResult<(Vec, Self)> { - let exprs: Vec = partition_keys.into_iter().map(|e| e.into()).collect(); + let exprs: Vec = partition_keys + .into_iter() + .map(std::convert::Into::into) + .collect(); py.allow_threads(|| { let (tables, values) = self.table.partition_by_value(exprs.as_slice())?; - let pytables = tables.into_iter().map(|t| t.into()).collect::>(); + let pytables = tables + .into_iter() + .map(std::convert::Into::into) + .collect::>(); let values = values.into(); Ok((pytables, values)) }) @@ -346,6 +365,7 @@ impl PyTable { Ok(self.table.size_bytes()?) } + #[must_use] pub fn column_names(&self) -> Vec { self.table.column_names() } @@ -414,13 +434,13 @@ impl PyTable { fields.reserve(dict.len()); columns.reserve(dict.len()); - for (name, series) in dict.into_iter() { + for (name, series) in dict { let series = series.series; fields.push(Field::new(name.as_str(), series.data_type().clone())); columns.push(series.rename(name)); } - let num_rows = columns.first().map(|s| s.len()).unwrap_or(0); + let num_rows = columns.first().map_or(0, daft_core::series::Series::len); if !columns.is_empty() { let first = columns.first().unwrap(); for s in columns.iter().skip(1) { diff --git a/src/daft-table/src/repr_html.rs b/src/daft-table/src/repr_html.rs index 79ecaf063a..0e46bb80b2 100644 --- a/src/daft-table/src/repr_html.rs +++ b/src/daft-table/src/repr_html.rs @@ -102,7 +102,7 @@ pub fn html_value(s: &Series, idx: usize) -> String { let arr = s.struct_().unwrap(); arr.html_value(idx) } - DataType::Map(_) => { + DataType::Map { .. 
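+// `DataType::Map` is matched as a struct-like variant (`Map { .. }`) throughout
+// this change, i.e. the variant now appears to carry named fields rather than
+// the single payload the old `Map(_)` pattern implied.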
} => { let arr = s.map().unwrap(); arr.html_value(idx) } diff --git a/src/hyperloglog/src/lib.rs b/src/hyperloglog/src/lib.rs index 3f63eebd55..240ce78220 100644 --- a/src/hyperloglog/src/lib.rs +++ b/src/hyperloglog/src/lib.rs @@ -60,6 +60,7 @@ impl Default for HyperLogLog<'_> { } impl<'a> HyperLogLog<'a> { + #[must_use] pub fn new_with_byte_slice(slice: &'a [u8]) -> Self { assert_eq!( slice.len(), @@ -77,6 +78,7 @@ impl<'a> HyperLogLog<'a> { impl HyperLogLog<'_> { /// Creates a new, empty HyperLogLog. + #[must_use] pub fn new() -> Self { let registers = [0; NUM_REGISTERS]; Self::new_with_registers(registers) @@ -85,6 +87,7 @@ impl HyperLogLog<'_> { /// Creates a HyperLogLog from already populated registers /// note that this method should not be invoked in untrusted environment /// because the internal structure of registers are not examined. + #[must_use] pub fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self { Self { registers: Cow::Owned(registers), @@ -127,6 +130,7 @@ impl HyperLogLog<'_> { } /// Guess the number of unique elements seen by the HyperLogLog. + #[must_use] pub fn count(&self) -> usize { let histogram = self.get_histogram(); let m = NUM_REGISTERS as f64; diff --git a/tests/actor_pool/test_pyactor_pool.py b/tests/actor_pool/test_pyactor_pool.py index e95feec9ed..f34d91bd7b 100644 --- a/tests/actor_pool/test_pyactor_pool.py +++ b/tests/actor_pool/test_pyactor_pool.py @@ -69,5 +69,7 @@ def test_pyactor_pool_not_enough_resources(): assert isinstance(runner, PyRunner) with pytest.raises(RuntimeError, match=f"Requested {float(cpu_count + 1)} CPUs but found only"): - with runner.actor_pool_context("my-pool", ResourceRequest(num_cpus=1), cpu_count + 1, projection) as _: + with runner.actor_pool_context( + "my-pool", ResourceRequest(num_cpus=1), ResourceRequest(), cpu_count + 1, projection + ) as _: pass diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index 46db61d47e..ddd2c9b040 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -199,6 +199,26 @@ def test_parquet_write_multifile_with_partitioning(tmp_path, smaller_parquet_tar assert readback["y"] == [y % 2 for y in data["x"]] +def test_parquet_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path) + + assert len(output_files) == 3 + + read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict() + assert read_back == data + + +def test_parquet_partitioned_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path, partition_cols=["x"]) + + assert len(output_files) == 3 + + read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict() + assert read_back == data + + def test_csv_write(tmp_path): df = daft.read_csv(COOKBOOK_DATA_CSV) @@ -262,3 +282,23 @@ def test_empty_csv_write_with_partitioning(tmp_path): assert len(pd_df) == 1 assert len(pd_df._preview.preview_partition) == 1 + + +def test_csv_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path) + + assert len(output_files) == 3 + + read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict() + assert read_back == data + + +def 
test_csv_partitioned_write_with_some_empty_partitions(tmp_path): + data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} + output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path, partition_cols=["x"]) + + assert len(output_files) == 3 + + read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict() + assert read_back == data diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b0bdbf9df4..5e79acf698 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -53,6 +53,17 @@ def test_columns_after_join(make_df): assert set(joined_df2.schema().column_names()) == set(["A", "B"]) +def test_rename_join_keys_in_dataframe(make_df): + df1 = make_df({"A": [1, 2], "B": [2, 2]}) + + df2 = make_df({"A": [1, 2]}) + joined_df1 = df1.join(df2, left_on=["A", "B"], right_on=["A", "A"]) + joined_df2 = df1.join(df2, left_on=["B", "A"], right_on=["A", "A"]) + + assert set(joined_df1.schema().column_names()) == set(["A", "B"]) + assert set(joined_df2.schema().column_names()) == set(["A", "B"]) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) @pytest.mark.parametrize( "join_strategy", diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py new file mode 100644 index 0000000000..464d20bd41 --- /dev/null +++ b/tests/dataframe/test_stddev.py @@ -0,0 +1,144 @@ +import functools +import math +from typing import Any, List, Tuple + +import pandas as pd +import pytest + +import daft + + +def grouped_stddev(rows) -> Tuple[List[Any], List[Any]]: + map = {} + for key, data in rows: + if key not in map: + map[key] = [] + map[key].append(data) + + keys = [] + stddevs = [] + for key, nums in map.items(): + keys.append(key) + stddevs.append(stddev(nums)) + + return keys, stddevs + + +def stddev(nums) -> float: + nums = [num for num in nums if num is not None] + + if not nums: + return 0.0 + sum_: float = sum(nums) + count = len(nums) + mean = sum_ / count + + squared_sums = functools.reduce(lambda acc, num: acc + (num - mean) ** 2, nums, 0) + stddev = math.sqrt(squared_sums / count) + return stddev + + +TESTS = [ + [nums := [0], stddev(nums)], + [nums := [1], stddev(nums)], + [nums := [0, 1, 2], stddev(nums)], + [nums := [100, 100, 100], stddev(nums)], + [nums := [None, 100, None], stddev(nums)], + [nums := [None] * 10 + [100], stddev(nums)], +] + + +@pytest.mark.parametrize("data_and_expected", TESTS) +def test_stddev_with_single_partition(data_and_expected): + data, expected = data_and_expected + df = daft.from_pydict({"a": data}) + result = df.agg(daft.col("a").stddev()).collect() + rows = result.iter_rows() + stddev = next(rows) + try: + next(rows) + assert False + except StopIteration: + pass + + assert stddev["a"] == expected + + +@pytest.mark.parametrize("data_and_expected", TESTS) +def test_stddev_with_multiple_partitions(data_and_expected): + data, expected = data_and_expected + df = daft.from_pydict({"a": data}).into_partitions(2) + result = df.agg(daft.col("a").stddev()).collect() + rows = result.iter_rows() + stddev = next(rows) + try: + next(rows) + assert False + except StopIteration: + pass + + assert stddev["a"] == expected + + +GROUPED_TESTS = [ + [rows := [("k1", 0), ("k2", 1), ("k1", 1)], *grouped_stddev(rows)], + [rows := [("k0", 100), ("k1", 100), ("k2", 100)], *grouped_stddev(rows)], + [rows := [("k0", 100), ("k0", 100), ("k0", 100)], *grouped_stddev(rows)], + [rows := [("k0", 0), ("k0", 1), ("k0", 2)], *grouped_stddev(rows)], + [rows := [("k0", None), ("k0", None), ("k0", 100)], 
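+    # stddev() above is the population form, sqrt(sum((x - mean)**2) / n), with
+    # Nones dropped first -- e.g. stddev([0, 1, 2]) == sqrt(2/3) ~= 0.8165 --
+    # which is what these cases assert daft's stddev aggregation matches.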
*grouped_stddev(rows)], +] + + +def unzip_rows(rows: list) -> Tuple[List, List]: + keys = [] + nums = [] + for key, data in rows: + keys.append(key) + nums.append(data) + return keys, nums + + +@pytest.mark.parametrize("data_and_expected", GROUPED_TESTS) +def test_grouped_stddev_with_single_partition(data_and_expected): + nums, expected_keys, expected_stddevs = data_and_expected + expected_df = daft.from_pydict({"keys": expected_keys, "data": expected_stddevs}) + keys, data = unzip_rows(nums) + df = daft.from_pydict({"keys": keys, "data": data}) + result_df = df.groupby("keys").agg(daft.col("data").stddev()).collect() + + result = result_df.to_pydict() + expected = expected_df.to_pydict() + + pd.testing.assert_series_equal( + pd.Series(result["keys"]).sort_values(), + pd.Series(expected["keys"]).sort_values(), + check_index=False, + ) + pd.testing.assert_series_equal( + pd.Series(result["data"]).sort_values(), + pd.Series(expected["data"]).sort_values(), + check_index=False, + ) + + +@pytest.mark.parametrize("data_and_expected", GROUPED_TESTS) +def test_grouped_stddev_with_multiple_partitions(data_and_expected): + nums, expected_keys, expected_stddevs = data_and_expected + expected_df = daft.from_pydict({"keys": expected_keys, "data": expected_stddevs}) + keys, data = unzip_rows(nums) + df = daft.from_pydict({"keys": keys, "data": data}).into_partitions(2) + result_df = df.groupby("keys").agg(daft.col("data").stddev()).collect() + + result = result_df.to_pydict() + expected = expected_df.to_pydict() + + pd.testing.assert_series_equal( + pd.Series(result["keys"]).sort_values(), + pd.Series(expected["keys"]).sort_values(), + check_index=False, + ) + pd.testing.assert_series_equal( + pd.Series(result["data"]).sort_values(), + pd.Series(expected["data"]).sort_values(), + check_index=False, + ) diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 8843028b01..599e63eaf9 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -152,6 +152,33 @@ def test_python_duration() -> None: assert res == duration +def test_temporal_arithmetic_with_duration_lit() -> None: + df = daft.from_pydict( + { + "duration": [timedelta(days=1)], + "date": [datetime(2021, 1, 1)], + "timestamp": [datetime(2021, 1, 1)], + } + ) + + df = df.select( + (df["date"] + timedelta(days=1)).alias("add_date"), + (df["date"] - timedelta(days=1)).alias("sub_date"), + (df["timestamp"] + timedelta(days=1)).alias("add_timestamp"), + (df["timestamp"] - timedelta(days=1)).alias("sub_timestamp"), + (df["duration"] + timedelta(days=1)).alias("add_dur"), + (df["duration"] - timedelta(days=1)).alias("sub_dur"), + ) + + result = df.to_pydict() + assert result["add_date"] == [datetime(2021, 1, 2)] + assert result["sub_date"] == [datetime(2020, 12, 31)] + assert result["add_timestamp"] == [datetime(2021, 1, 2)] + assert result["sub_timestamp"] == [datetime(2020, 12, 31)] + assert result["add_dur"] == [timedelta(days=2)] + assert result["sub_dur"] == [timedelta(0)] + + @pytest.mark.parametrize( "timeunit", ["s", "ms", "us", "ns"], diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index d3727c2ac3..755abf5502 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -6,6 +6,7 @@ import pytest import pytz +import daft from daft.datatype import DataType, TimeUnit from daft.expressions import col, lit from daft.expressions.testing import expr_structurally_equal @@ -504,7 +505,108 @@ def 
test_datetime_lit_different_timeunits(timeunit, expected) -> None: assert timestamp_repr == expected +@pytest.mark.parametrize( + "input, expected", + [ + ( + timedelta(days=1), + "lit(1d)", + ), + ( + timedelta(days=1, hours=12, minutes=30, seconds=59), + "lit(1d 12h 30m 59s)", + ), + ( + timedelta(days=1, hours=12, minutes=30, seconds=59, microseconds=123456), + "lit(1d 12h 30m 59s 123456µs)", + ), + ( + timedelta(weeks=1, days=1, hours=12, minutes=30, seconds=59, microseconds=123456), + "lit(8d 12h 30m 59s 123456µs)", + ), + ], +) +def test_duration_lit(input, expected) -> None: + d = lit(input) + output = repr(d) + assert output == expected + + def test_repr_series_lit() -> None: s = lit(Series.from_pylist([1, 2, 3])) output = repr(s) assert output == "lit([1, 2, 3])" + + +def test_list_value_counts(): + # Create a MicroPartition with a list column + mp = MicroPartition.from_pydict( + {"list_col": [["a", "b", "a", "c"], ["b", "b", "c"], ["a", "a", "a"], [], ["d", None, "d"]]} + ) + + # Apply list_value_counts operation + result = mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + value_counts = result.to_pydict()["value_counts"] + + # Expected output + expected = [[("a", 2), ("b", 1), ("c", 1)], [("b", 2), ("c", 1)], [("a", 3)], [], [("d", 2)]] + + # Check the result + assert value_counts == expected + + # Test with empty input (no proper type -> should raise error) + empty_mp = MicroPartition.from_pydict({"list_col": []}) + with pytest.raises(ValueError): + empty_mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + + # Test with empty input (no proper type -> should raise error) + none_mp = MicroPartition.from_pydict({"list_col": [None, None, None]}) + with pytest.raises(ValueError): + none_mp.eval_expression_list([col("list_col").list.value_counts().alias("value_counts")]) + + +def test_list_value_counts_nested(): + # Create a MicroPartition with a nested list column + mp = MicroPartition.from_pydict( + { + "nested_list_col": [ + [[1, 2], [3, 4]], + [[1, 2], [5, 6]], + [[3, 4], [1, 2]], + [], + None, + [[1, 2], [1, 2]], + ] + } + ) + + # Apply list_value_counts operation and expect an exception + with pytest.raises(daft.exceptions.DaftCoreException) as exc_info: + mp.eval_expression_list([col("nested_list_col").list.value_counts().alias("value_counts")]) + + # Check the exception message + assert ( + 'DaftError::ArrowError Invalid argument error: The data type type LargeList(Field { name: "item", data_type: Int64, is_nullable: true, metadata: {} }) has no natural order' + in str(exc_info.value) + ) + + +def test_list_value_counts_degenerate(): + import pyarrow as pa + + # Create a MicroPartition with an empty list column of specified type + empty_mp = MicroPartition.from_pydict({"empty_list_col": pa.array([], type=pa.list_(pa.string()))}) + + # Apply list_value_counts operation + result = empty_mp.eval_expression_list([col("empty_list_col").list.value_counts().alias("value_counts")]) + + # Check the result + assert result.to_pydict() == {"value_counts": []} + + # Test with null values + null_mp = MicroPartition.from_pydict({"null_list_col": pa.array([None, None], type=pa.list_(pa.string()))}) + + result_null = null_mp.eval_expression_list([col("null_list_col").list.value_counts().alias("value_counts")]) + + # Check the result for null values + assert result_null.to_pydict() == {"value_counts": [[], []]} diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py index 2572eb1adc..5ac09387e0 100644 
--- a/tests/expressions/test_udf.py +++ b/tests/expressions/test_udf.py @@ -4,7 +4,9 @@ import pyarrow as pa import pytest +import daft from daft import col +from daft.context import get_context, set_planning_config from daft.datatype import DataType from daft.expressions import Expression from daft.expressions.testing import expr_structurally_equal @@ -13,6 +15,21 @@ from daft.udf import udf +@pytest.fixture(scope="function", params=[False, True]) +def actor_pool_enabled(request): + if request.param and get_context().daft_execution_config.enable_native_executor: + pytest.skip("Native executor does not support stateful UDFs") + + original_config = get_context().daft_planning_config + try: + set_planning_config( + config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=request.param) + ) + yield request.param + finally: + set_planning_config(config=original_config) + + def test_udf(): table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) @@ -30,8 +47,8 @@ def repeat_n(data, n): @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10]) -def test_class_udf(batch_size): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf(batch_size, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string(), batch_size=batch_size) class RepeatN: @@ -41,18 +58,21 @@ def __init__(self): def __call__(self, data): return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + if actor_pool_enabled: + RepeatN = RepeatN.with_concurrency(1) + expr = RepeatN(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == "a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10]) -def test_class_udf_init_args(batch_size): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf_init_args(batch_size, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string(), batch_size=batch_size) class RepeatN: @@ -62,24 +82,27 @@ def __init__(self, initial_n: int = 2): def __call__(self, data): return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()]) + if actor_pool_enabled: + RepeatN = RepeatN.with_concurrency(1) + expr = RepeatN(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == "a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]} expr = RepeatN.with_init_args(initial_n=3)(col("a")) - field = expr._to_field(table.schema()) + field = expr._to_field(df.schema()) assert field.name == "a" assert field.dtype == DataType.string() - result = table.eval_expression_list([expr]) + result = df.select(expr) assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]} @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10]) -def test_class_udf_init_args_no_default(batch_size): - table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]}) +def test_class_udf_init_args_no_default(batch_size, actor_pool_enabled): + df = daft.from_pydict({"a": ["foo", "bar", "baz"]}) @udf(return_dtype=DataType.string(), batch_size=batch_size) class RepeatN: @@ -89,18 +112,21 @@ def 
@@ -89,18 +112,21 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."):
         RepeatN(col("a"))
 
     expr = RepeatN.with_init_args(initial_n=2)(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
 
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
 
-def test_class_udf_init_args_bad_args():
+def test_class_udf_init_args_bad_args(actor_pool_enabled):
     @udf(return_dtype=DataType.string())
     class RepeatN:
         def __init__(self, initial_n):
@@ -109,10 +135,37 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"):
         RepeatN.with_init_args(wrong=5)
 
 
+@pytest.mark.parametrize("concurrency", [1, 2, 4])
+@pytest.mark.parametrize("actor_pool_enabled", [True], indirect=True)
+def test_stateful_udf_concurrency(concurrency, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
+
+    @udf(return_dtype=DataType.string(), batch_size=1)
+    class RepeatN:
+        def __init__(self):
+            self.n = 2
+
+        def __call__(self, data):
+            return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
+
+    RepeatN = RepeatN.with_concurrency(concurrency)
+
+    expr = RepeatN(col("a"))
+    field = expr._to_field(df.schema())
+    assert field.name == "a"
+    assert field.dtype == DataType.string()
+
+    result = df.select(expr)
+    assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
+
+
 def test_udf_kwargs():
     table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
 
@@ -208,8 +261,8 @@ def full_udf(e_arg, val, kwarg_val=None, kwarg_ex=None):
         full_udf()
 
 
-def test_class_udf_initialization_error():
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_initialization_error(actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string())
     class IdentityWithInitError:
@@ -219,9 +272,16 @@ def __init__(self):
         def __call__(self, data):
             return data
 
+    if actor_pool_enabled:
+        IdentityWithInitError = IdentityWithInitError.with_concurrency(1)
+
     expr = IdentityWithInitError(col("a"))
-    with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
-        table.eval_expression_list([expr])
+    if actor_pool_enabled:
+        with pytest.raises(Exception):
+            df.select(expr).collect()
+    else:
+        with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
+            df.select(expr).collect()
 
 
 def test_udf_equality():
diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py
index 2252f446d6..5cc1f2bf7d 100644
--- a/tests/integration/iceberg/test_table_load.py
+++ b/tests/integration/iceberg/test_table_load.py
@@ -24,7 +24,7 @@ def test_daft_iceberg_table_open(local_iceberg_tables):
 
 
 WORKING_SHOW_COLLECT = [
-    "test_all_types",
+    # "test_all_types",  # Commented out due to issue https://github.com/Eventual-Inc/Daft/issues/2996
     "test_limit",
     "test_null_nan",
     "test_null_nan_rewritten",
diff --git a/tests/io/delta_lake/test_table_read.py b/tests/io/delta_lake/test_table_read.py
index 9cb5881a72..273006659f 100644
--- a/tests/io/delta_lake/test_table_read.py
+++ b/tests/io/delta_lake/test_table_read.py
@@ -94,3 +94,25 @@ def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table):
     df = df.limit(2)
     df.collect()
     assert len(df) == 2, "Length of non-materialized data when read through deltalake should be correct"
+
+
+def test_deltalake_read_versioned(tmp_path, base_table):
+    deltalake = pytest.importorskip("deltalake")
+    path = tmp_path / "some_table"
+    deltalake.write_deltalake(path, base_table)
+
+    updated_columns = base_table.columns + [pa.array(["x", "y", "z"])]
+    updated_column_names = base_table.column_names + ["new_column"]
+    updated_table = pa.Table.from_arrays(updated_columns, names=updated_column_names)
+    deltalake.write_deltalake(path, updated_table, mode="overwrite", schema_mode="overwrite")
+
+    for version in [None, 1]:
+        df = daft.read_deltalake(str(path), version=version)
+        expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path).schema().to_pyarrow())
+        assert df.schema() == expected_schema
+        assert_pyarrow_tables_equal(df.to_arrow(), updated_table)
+
+    df = daft.read_deltalake(str(path), version=0)
+    expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path, version=0).schema().to_pyarrow())
+    assert df.schema() == expected_schema
+    assert_pyarrow_tables_equal(df.to_arrow(), base_table)
diff --git a/tests/io/delta_lake/test_table_write.py b/tests/io/delta_lake/test_table_write.py
index 6519e85d0f..7a65d835cb 100644
--- a/tests/io/delta_lake/test_table_write.py
+++ b/tests/io/delta_lake/test_table_write.py
@@ -185,6 +185,21 @@ def test_deltalake_write_ignore(tmp_path):
     assert read_delta.to_pyarrow_table() == df1.to_arrow()
 
 
+def test_deltalake_write_with_empty_partition(tmp_path, base_table):
+    deltalake = pytest.importorskip("deltalake")
+    path = tmp_path / "some_table"
+    df = daft.from_arrow(base_table).into_partitions(4)
+    result = df.write_deltalake(str(path))
+    result = result.to_pydict()
+    assert result["operation"] == ["ADD", "ADD", "ADD"]
+    assert result["rows"] == [1, 1, 1]
+
+    read_delta = deltalake.DeltaTable(str(path))
+    expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow())
+    assert df.schema() == expected_schema
+    assert read_delta.to_pyarrow_table() == base_table
+
+
 def check_equal_both_daft_and_delta_rs(df: daft.DataFrame, path: Path, sort_order: list[tuple[str, str]]):
     deltalake = pytest.importorskip("deltalake")
 
@@ -256,6 +271,16 @@ def test_deltalake_write_partitioned_empty(tmp_path):
     check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")])
 
 
+def test_deltalake_write_partitioned_some_empty(tmp_path):
+    path = tmp_path / "some_table"
+
+    df = daft.from_pydict({"int": [1, 2, 3, None], "string": ["foo", "foo", "bar", None]}).into_partitions(5)
+
+    df.write_deltalake(str(path), partition_cols=["int"])
+
+    check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")])
+
+
 def test_deltalake_write_partitioned_existing_table(tmp_path):
     path = tmp_path / "some_table"
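The versioned-read test above doubles as usage documentation: `daft.read_deltalake` accepts a `version` keyword, where `None` (the default) resolves to the latest snapshot and an integer time-travels to that commit. A minimal reader-side sketch (the path and table contents here are illustrative, not from the test suite):

    import daft

    latest = daft.read_deltalake("/data/some_table")             # version=None -> latest snapshot
    first = daft.read_deltalake("/data/some_table", version=0)   # time travel to the first commit
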
diff --git a/tests/io/iceberg/test_iceberg_writes.py b/tests/io/iceberg/test_iceberg_writes.py
index aab68a0c5c..7f85447dba 100644
--- a/tests/io/iceberg/test_iceberg_writes.py
+++ b/tests/io/iceberg/test_iceberg_writes.py
@@ -209,6 +209,18 @@ def test_read_after_write_nested_fields(local_catalog):
     assert as_arrow == read_back.to_arrow()
 
 
+def test_read_after_write_with_empty_partition(local_catalog):
+    df = daft.from_pydict({"x": [1, 2, 3]}).into_partitions(4)
+    as_arrow = df.to_arrow()
+    table = local_catalog.create_table("default.test", as_arrow.schema)
+    result = df.write_iceberg(table)
+    as_dict = result.to_pydict()
+    assert as_dict["operation"] == ["ADD", "ADD", "ADD"]
+    assert as_dict["rows"] == [1, 1, 1]
+    read_back = daft.read_iceberg(table)
+    assert as_arrow == read_back.to_arrow()
+
+
 @pytest.fixture
 def complex_table() -> tuple[pa.Table, Schema]:
     table = pa.table(
diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py
index 6904805831..292c5b98e1 100644
--- a/tests/io/test_parquet_roundtrip.py
+++ b/tests/io/test_parquet_roundtrip.py
@@ -112,15 +112,31 @@ def test_roundtrip_temporal_arrow_types(tmp_path, data, pa_type, expected_dtype)
 
 
 def test_roundtrip_tensor_types(tmp_path):
-    expected_dtype = DataType.tensor(DataType.int64())
-    data = [np.array([[1, 2], [3, 4]]), None, None]
-    before = daft.from_pydict({"foo": Series.from_pylist(data)})
-    before = before.concat(before)
-    before.write_parquet(str(tmp_path))
-    after = daft.read_parquet(str(tmp_path))
-    assert before.schema()["foo"].dtype == expected_dtype
-    assert after.schema()["foo"].dtype == expected_dtype
-    assert before.to_arrow() == after.to_arrow()
+    # Define the expected data type for the tensor column
+    expected_tensor_dtype = DataType.tensor(DataType.int64())
+
+    # Create sample tensor data with some null values
+    tensor_data = [np.array([[1, 2], [3, 4]]), None, None]
+
+    # Create a Daft DataFrame with the tensor data
+    df_original = daft.from_pydict({"tensor_col": Series.from_pylist(tensor_data)})
+
+    # Double the size of the DataFrame to ensure we test with more data
+    df_original = df_original.concat(df_original)
+
+    assert df_original.schema()["tensor_col"].dtype == expected_tensor_dtype
+
+    # Write the DataFrame to a Parquet file
+    df_original.write_parquet(str(tmp_path))
+
+    # Read the Parquet file back into a new DataFrame
+    df_roundtrip = daft.read_parquet(str(tmp_path))
+
+    # Verify that the data type is preserved after the roundtrip
+    assert df_roundtrip.schema()["tensor_col"].dtype == expected_tensor_dtype
+
+    # Ensure the data content is identical after the roundtrip
+    assert df_original.to_arrow() == df_roundtrip.to_arrow()
 
 
 @pytest.mark.parametrize("fixed_shape", [True, False])
diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py
index cd0e74bdfa..eb53334ea8 100644
--- a/tests/series/test_cast.py
+++ b/tests/series/test_cast.py
@@ -262,7 +262,7 @@ def test_cast_binary_to_fixed_size_binary():
     assert casted.to_pylist() == [b"abc", b"def", None, b"bcd", None]
 
 
-def test_cast_binary_to_fixed_size_binary_fails_with_variable_lengths():
+def test_cast_binary_to_fixed_size_binary_fails_with_variable_length():
     data = [b"abc", b"def", None, b"bcd", None, b"long"]
 
     input = Series.from_pylist(data)
@@ -368,7 +368,7 @@ def test_series_cast_python_to_list(dtype) -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [3, 3, 3, 3, 2, 2, None]
+    assert t.list.length().to_pylist() == [3, 3, 3, 3, 2, 2, None]
 
     pydata = t.to_pylist()
     assert pydata[-1] is None
@@ -397,7 +397,7 @@ def test_series_cast_python_to_fixed_size_list(dtype) -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [3, 3, 3, 3, 3, 3, None]
+    assert t.list.length().to_pylist() == [3, 3, 3, 3, 3, 3, None]
 
     pydata = t.to_pylist()
     assert pydata[-1] is None
@@ -426,7 +426,7 @@ def test_series_cast_python_to_embedding(dtype) -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [3, 3, 3, 3, 3, 3, None]
+    assert t.list.length().to_pylist() == [3, 3, 3, 3, 3, 3, None]
 
     pydata = t.to_pylist()
     assert pydata[-1] is None
@@ -448,7 +448,7 @@ def test_series_cast_list_to_embedding(dtype) -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [3, 3, 3, None]
+    assert t.list.length().to_pylist() == [3, 3, 3, None]
 
     pydata = t.to_pylist()
     assert pydata[-1] is None
@@ -473,7 +473,7 @@ def test_series_cast_numpy_to_image() -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [12, 27, None]
+    assert t.list.length().to_pylist() == [12, 27, None]
 
    pydata = t.to_pylist()
     assert pydata[-1] is None
@@ -495,7 +495,7 @@ def test_series_cast_numpy_to_image_infer_mode() -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [4, 27, None]
+    assert t.list.length().to_pylist() == [4, 27, None]
 
     pydata = t.to_arrow().to_pylist()
     assert pydata[0] == {
@@ -536,7 +536,7 @@ def test_series_cast_python_to_fixed_shape_image() -> None:
     assert t.datatype() == target_dtype
     assert len(t) == len(data)
 
-    assert t.list.lengths().to_pylist() == [12, 12, None]
+    assert t.list.length().to_pylist() == [12, 12, None]
 
     pydata = t.to_pylist()
     assert pydata[-1] is None
@@ -1161,3 +1161,38 @@ def test_series_cast_fixed_size_list_to_list() -> None:
     assert data.datatype() == DataType.fixed_size_list(DataType.int64(), 2)
     casted = data.cast(DataType.list(DataType.int64()))
     assert casted.to_pylist() == [[1, 2], [3, 4], [5, 6]]
+
+
+### Sparse ###
+
+
+def to_coo_sparse_dict(ndarray: np.ndarray) -> dict[str, np.ndarray]:
+    flat_array = ndarray.ravel()
+    indices = np.flatnonzero(flat_array).astype(np.uint64)
+    values = flat_array[indices]
+    shape = list(ndarray.shape)
+    return {"values": values, "indices": indices, "shape": shape}
+
+
+def test_series_cast_sparse_to_python() -> None:
+    data = [np.zeros(shape=(1, 2), dtype=np.uint8), None, np.ones(shape=(2, 2), dtype=np.uint8)]
+    series = Series.from_pylist(data).cast(DataType.sparse_tensor(DataType.uint8()))
+    assert series.datatype() == DataType.sparse_tensor(DataType.uint8())
+
+    given = series.to_pylist()
+    expected = [to_coo_sparse_dict(ndarray) if ndarray is not None else None for ndarray in data]
+    np.testing.assert_equal(given, expected)
+
+
+def test_series_cast_fixed_shape_sparse_to_python() -> None:
+    data = [np.zeros(shape=(2, 2), dtype=np.uint8), None, np.ones(shape=(2, 2), dtype=np.uint8)]
+    series = (
+        Series.from_pylist(data)
+        .cast(DataType.tensor(DataType.uint8(), shape=(2, 2)))  # TODO: direct cast to fixed shape sparse
+        .cast(DataType.sparse_tensor(DataType.uint8(), shape=(2, 2)))
+    )
+    assert series.datatype() == DataType.sparse_tensor(DataType.uint8(), shape=(2, 2))
+
+    given = series.to_pylist()
+    expected = [to_coo_sparse_dict(ndarray) if ndarray is not None else None for ndarray in data]
+    np.testing.assert_equal(given, expected)
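For orientation, the COO ("coordinate") encoding these sparse-tensor casts round-trip through stores only the nonzero entries of the flattened tensor, plus the original shape. A small worked example of the `to_coo_sparse_dict` helper defined above (the values follow directly from its definition):

    import numpy as np

    arr = np.array([[0, 5], [0, 7]], dtype=np.uint8)
    # Flattened, arr is [0, 5, 0, 7]: nonzeros sit at flat indices 1 and 3.
    np.testing.assert_equal(
        to_coo_sparse_dict(arr),
        {
            "values": np.array([5, 7], dtype=np.uint8),
            "indices": np.array([1, 3], dtype=np.uint64),
            "shape": [2, 2],
        },
    )
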
"floats": [1.5, 2.5, 3.5, 4.5], + } + ) + + actual = ( + daft.sql(""" + SELECT + hash(a) as hash_a, + hash(a, 0) as hash_a_0, + hash(a, seed:=0) as hash_a_seed_0, + minhash(a, num_hashes:=10, ngram_size:= 100, seed:=10) as minhash_a, + minhash(a, num_hashes:=10, ngram_size:= 100) as minhash_a_no_seed, + FROM df + """) + .collect() + .to_pydict() + ) + + expected = ( + df.select( + col("a").hash().alias("hash_a"), + col("a").hash(0).alias("hash_a_0"), + col("a").hash(seed=0).alias("hash_a_seed_0"), + col("a").minhash(num_hashes=10, ngram_size=100, seed=10).alias("minhash_a"), + col("a").minhash(num_hashes=10, ngram_size=100).alias("minhash_a_no_seed"), + ) + .collect() + .to_pydict() + ) + + assert actual == expected + + with pytest.raises(Exception, match="Invalid arguments for minhash"): + daft.sql("SELECT minhash() as hash_a FROM df").collect() + + with pytest.raises(Exception, match="num_hashes is required"): + daft.sql("SELECT minhash(a) as hash_a FROM df").collect() diff --git a/tests/sql/test_table_funcs.py b/tests/sql/test_table_funcs.py new file mode 100644 index 0000000000..16fbf040c7 --- /dev/null +++ b/tests/sql/test_table_funcs.py @@ -0,0 +1,7 @@ +import daft + + +def test_sql_read_parquet(): + df = daft.sql("SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')").collect() + expected = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect() + assert df.to_pydict() == expected.to_pydict() diff --git a/tests/sql/test_utf8_exprs.py b/tests/sql/test_utf8_exprs.py new file mode 100644 index 0000000000..12b53a9ebc --- /dev/null +++ b/tests/sql/test_utf8_exprs.py @@ -0,0 +1,113 @@ +import daft +from daft import col + + +def test_utf8_exprs(): + df = daft.from_pydict( + { + "a": [ + "a", + "df_daft", + "foo", + "bar", + "baz", + "lorém", + "ipsum", + "dolor", + "sit", + "amet", + "😊", + "🌟", + "🎉", + "This is a longer with some words", + "THIS is ", + "", + ], + } + ) + + sql = """ + SELECT + ends_with(a, 'a') as ends_with_a, + starts_with(a, 'a') as starts_with_a, + contains(a, 'a') as contains_a, + split(a, ' ') as split_a, + regexp_match(a, 'ba.') as match_a, + regexp_extract(a, 'ba.') as extract_a, + regexp_extract_all(a, 'ba.') as extract_all_a, + regexp_replace(a, 'ba.', 'foo') as replace_a, + regexp_split(a, '\\s+') as regexp_split_a, + length(a) as length_a, + length_bytes(a) as length_bytes_a, + lower(a) as lower_a, + lstrip(a) as lstrip_a, + rstrip(a) as rstrip_a, + reverse(a) as reverse_a, + capitalize(a) as capitalize_a, + left(a, 4) as left_a, + right(a, 4) as right_a, + find(a, 'a') as find_a, + rpad(a, 10, '<') as rpad_a, + lpad(a, 10, '>') as lpad_a, + repeat(a, 2) as repeat_a, + a like 'a%' as like_a, + a ilike 'a%' as ilike_a, + substring(a, 1, 3) as substring_a, + count_matches(a, 'a') as count_matches_a_0, + count_matches(a, 'a', case_sensitive := true) as count_matches_a_1, + count_matches(a, 'a', case_sensitive := false, whole_words := false) as count_matches_a_2, + count_matches(a, 'a', case_sensitive := true, whole_words := true) as count_matches_a_3, + normalize(a) as normalize_a, + normalize(a, remove_punct:=true) as normalize_remove_punct_a, + normalize(a, remove_punct:=true, lowercase:=true) as normalize_remove_punct_lower_a, + normalize(a, remove_punct:=true, lowercase:=true, white_space:=true) as normalize_remove_punct_lower_ws_a, + tokenize_encode(a, 'r50k_base') as tokenize_encode_a, + tokenize_decode(tokenize_encode(a, 'r50k_base'), 'r50k_base') as tokenize_decode_a + FROM df + """ + actual = daft.sql(sql).collect() + 
+    expected = (
+        df.select(
+            col("a").str.endswith("a").alias("ends_with_a"),
+            col("a").str.startswith("a").alias("starts_with_a"),
+            col("a").str.contains("a").alias("contains_a"),
+            col("a").str.split(" ").alias("split_a"),
+            col("a").str.match("ba.").alias("match_a"),
+            col("a").str.extract("ba.").alias("extract_a"),
+            col("a").str.extract_all("ba.").alias("extract_all_a"),
+            col("a").str.replace("ba.", "foo").alias("replace_a"),
+            col("a").str.split(r"\s+", regex=True).alias("regexp_split_a"),
+            col("a").str.length().alias("length_a"),
+            col("a").str.length_bytes().alias("length_bytes_a"),
+            col("a").str.lower().alias("lower_a"),
+            col("a").str.lstrip().alias("lstrip_a"),
+            col("a").str.rstrip().alias("rstrip_a"),
+            col("a").str.reverse().alias("reverse_a"),
+            col("a").str.capitalize().alias("capitalize_a"),
+            col("a").str.left(4).alias("left_a"),
+            col("a").str.right(4).alias("right_a"),
+            col("a").str.find("a").alias("find_a"),
+            col("a").str.rpad(10, "<").alias("rpad_a"),
+            col("a").str.lpad(10, ">").alias("lpad_a"),
+            col("a").str.repeat(2).alias("repeat_a"),
+            col("a").str.like("a%").alias("like_a"),
+            col("a").str.ilike("a%").alias("ilike_a"),
+            col("a").str.substr(1, 3).alias("substring_a"),
+            col("a").str.count_matches("a").alias("count_matches_a_0"),
+            col("a").str.count_matches("a", case_sensitive=True).alias("count_matches_a_1"),
+            col("a").str.count_matches("a", case_sensitive=False, whole_words=False).alias("count_matches_a_2"),
+            col("a").str.count_matches("a", case_sensitive=True, whole_words=True).alias("count_matches_a_3"),
+            col("a").str.normalize().alias("normalize_a"),
+            col("a").str.normalize(remove_punct=True).alias("normalize_remove_punct_a"),
+            col("a").str.normalize(remove_punct=True, lowercase=True).alias("normalize_remove_punct_lower_a"),
+            col("a")
+            .str.normalize(remove_punct=True, lowercase=True, white_space=True)
+            .alias("normalize_remove_punct_lower_ws_a"),
+            col("a").str.tokenize_encode("r50k_base").alias("tokenize_encode_a"),
+            col("a").str.tokenize_encode("r50k_base").str.tokenize_decode("r50k_base").alias("tokenize_decode_a"),
+        )
+        .collect()
+        .to_pydict()
+    )
+    actual = actual.to_pydict()
+    assert actual == expected
diff --git a/tests/table/list/test_list_count_lengths.py b/tests/table/list/test_list_count_length.py
similarity index 83%
rename from tests/table/list/test_list_count_lengths.py
rename to tests/table/list/test_list_count_length.py
index 321219d6da..7e48f3088d 100644
--- a/tests/table/list/test_list_count_lengths.py
+++ b/tests/table/list/test_list_count_length.py
@@ -50,3 +50,12 @@ def test_fixed_list_count(fixed_table):
     result = fixed_table.eval_expression_list([col("col").list.count(CountMode.Null)])
 
     assert result.to_pydict() == {"col": [0, 0, 1, 2, None]}
+
+
+def test_list_length(fixed_table):
+    with pytest.warns(DeprecationWarning):
+        lengths_result = fixed_table.eval_expression_list([col("col").list.lengths()])
+    length_result = fixed_table.eval_expression_list([col("col").list.length()])
+
+    assert lengths_result.to_pydict() == {"col": [2, 2, 2, 2, None]}
+    assert length_result.to_pydict() == {"col": [2, 2, 2, 2, None]}
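The rename above captures the deprecation story in one place: `list.lengths()` still works but now emits a `DeprecationWarning`, and `list.length()` is its replacement (the same rename is applied to the casts, the tutorial notebook, and the SQL `length` function elsewhere in this diff). Migrating is a one-token change, sketched here:

    import daft
    from daft import col

    df = daft.from_pydict({"col": [[1, 2], [3]]})
    df.select(col("col").list.lengths())  # deprecated; emits DeprecationWarning
    df.select(col("col").list.length())   # preferred replacement
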
diff --git a/tests/table/map/test_map_get.py b/tests/table/map/test_map_get.py
index 6ab7a31ab8..9c7548d1e3 100644
--- a/tests/table/map/test_map_get.py
+++ b/tests/table/map/test_map_get.py
@@ -49,7 +49,8 @@ def test_map_get_logical_type():
     )
     table = MicroPartition.from_arrow(pa.table({"map_col": data}))
 
-    result = table.eval_expression_list([col("map_col").map.get("foo")])
+    map = col("map_col").map
+    result = table.eval_expression_list([map.get("foo")])
 
     assert result.to_pydict() == {"value": [datetime.date(2022, 1, 1), datetime.date(2022, 1, 2), None]}
 
diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py
index 01749a1cdb..b71df2bfdd 100644
--- a/tests/table/test_table_aggs.py
+++ b/tests/table/test_table_aggs.py
@@ -669,6 +669,32 @@ def test_global_pyobj_list_aggs() -> None:
     assert result.to_pydict()["list"][0] == input
 
 
+def test_global_list_list_aggs() -> None:
+    input = [[1], [2, 3, 4], [5, None], [], None]
+    table = MicroPartition.from_pydict({"input": input})
+    result = table.eval_expression_list([col("input").alias("list").agg_list()])
+    assert result.get_column("list").datatype() == DataType.list(DataType.list(DataType.int64()))
+    assert result.to_pydict()["list"][0] == input
+
+
+def test_global_fixed_size_list_list_aggs() -> None:
+    input = Series.from_pylist([[1, 2], [3, 4], [5, None], None]).cast(DataType.fixed_size_list(DataType.int64(), 2))
+    table = MicroPartition.from_pydict({"input": input})
+    result = table.eval_expression_list([col("input").alias("list").agg_list()])
+    assert result.get_column("list").datatype() == DataType.list(DataType.fixed_size_list(DataType.int64(), 2))
+    assert result.to_pydict()["list"][0] == [[1, 2], [3, 4], [5, None], None]
+
+
+def test_global_struct_list_aggs() -> None:
+    input = [{"a": 1, "b": 2}, {"a": 3, "b": None}, None]
+    table = MicroPartition.from_pydict({"input": input})
+    result = table.eval_expression_list([col("input").alias("list").agg_list()])
+    assert result.get_column("list").datatype() == DataType.list(
+        DataType.struct({"a": DataType.int64(), "b": DataType.int64()})
+    )
+    assert result.to_pydict()["list"][0] == input
+
+
 @pytest.mark.parametrize(
     "dtype", daft_nonnull_types + daft_null_types, ids=[f"{_}" for _ in daft_nonnull_types + daft_null_types]
 )
@@ -701,6 +727,58 @@ def test_grouped_pyobj_list_aggs() -> None:
     assert result.to_pydict() == {"groups": [1, 2, None], "list": expected_groups}
 
 
+def test_grouped_list_list_aggs() -> None:
+    groups = [None, 1, None, 1, 2, 2]
+    input = [[1], [2, 3, 4], [5, None], None, [], [8, 9]]
+    expected_idx = [[1, 3], [4, 5], [0, 2]]
+
+    daft_table = MicroPartition.from_pydict({"groups": groups, "input": input})
+    daft_table = daft_table.eval_expression_list([col("groups"), col("input")])
+    result = daft_table.agg([col("input").alias("list").agg_list()], group_by=[col("groups")]).sort([col("groups")])
+    assert result.get_column("list").datatype() == DataType.list(DataType.list(DataType.int64()))
+
+    input_as_dtype = daft_table.get_column("input").to_pylist()
+    expected_groups = [[input_as_dtype[i] for i in group] for group in expected_idx]
+
+    assert result.to_pydict() == {"groups": [1, 2, None], "list": expected_groups}
+
+
+def test_grouped_fixed_size_list_list_aggs() -> None:
+    groups = [None, 1, None, 1, 2, 2]
+    input = Series.from_pylist([[1, 2], [3, 4], [5, None], None, [6, 7], [8, 9]]).cast(
+        DataType.fixed_size_list(DataType.int64(), 2)
+    )
+    expected_idx = [[1, 3], [4, 5], [0, 2]]
+
+    daft_table = MicroPartition.from_pydict({"groups": groups, "input": input})
+    daft_table = daft_table.eval_expression_list([col("groups"), col("input")])
+    result = daft_table.agg([col("input").alias("list").agg_list()], group_by=[col("groups")]).sort([col("groups")])
+    assert result.get_column("list").datatype() == DataType.list(DataType.fixed_size_list(DataType.int64(), 2))
+
+    input_as_dtype = daft_table.get_column("input").to_pylist()
+    expected_groups = [[input_as_dtype[i] for i in group] for group in expected_idx]
+
+    assert result.to_pydict() == {"groups": [1, 2, None], "list": expected_groups}
+
+
+def test_grouped_struct_list_aggs() -> None:
+    groups = [None, 1, None, 1, 2, 2]
+    input = [{"x": 1, "y": 2}, {"x": 3, "y": 4}, {"x": 5, "y": None}, None, {"x": 6, "y": 7}, {"x": 8, "y": 9}]
+    expected_idx = [[1, 3], [4, 5], [0, 2]]
+
+    daft_table = MicroPartition.from_pydict({"groups": groups, "input": input})
+    daft_table = daft_table.eval_expression_list([col("groups"), col("input")])
+    result = daft_table.agg([col("input").alias("list").agg_list()], group_by=[col("groups")]).sort([col("groups")])
+    assert result.get_column("list").datatype() == DataType.list(
+        DataType.struct({"x": DataType.int64(), "y": DataType.int64()})
+    )
+
+    input_as_dtype = daft_table.get_column("input").to_pylist()
+    expected_groups = [[input_as_dtype[i] for i in group] for group in expected_idx]
+
+    assert result.to_pydict() == {"groups": [1, 2, None], "list": expected_groups}
+
+
 def test_list_aggs_empty() -> None:
     daft_table = MicroPartition.from_pydict({"col_A": [], "col_B": []})
     daft_table = daft_table.agg(
diff --git a/tests/test_resource_requests.py b/tests/test_resource_requests.py
index 20a2aadf21..ec867aada7 100644
--- a/tests/test_resource_requests.py
+++ b/tests/test_resource_requests.py
@@ -8,7 +8,7 @@
 import daft
 from daft import context, udf
-from daft.context import get_context
+from daft.context import get_context, set_planning_config
 from daft.daft import SystemInfo
 from daft.expressions import col
 from daft.internal.gpu import cuda_device_count
@@ -127,6 +127,19 @@ def test_requesting_too_much_memory():
 ###
 
 
+@pytest.fixture(scope="function", params=[True])
+def enable_actor_pool():
+    try:
+        original_config = get_context().daft_planning_config
+
+        set_planning_config(
+            config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=True)
+        )
+        yield
+    finally:
+        set_planning_config(config=original_config)
+
+
 @udf(return_dtype=daft.DataType.int64())
 def assert_resources(c, num_cpus=None, num_gpus=None, memory=None):
     assigned_resources = ray.get_runtime_context().get_assigned_resources()
@@ -141,6 +154,24 @@ def assert_resources(c, num_cpus=None, num_gpus=None, memory=None):
     return c
 
 
+@udf(return_dtype=daft.DataType.int64())
+class AssertResourcesStateful:
+    def __init__(self):
+        pass
+
+    def __call__(self, c, num_cpus=None, num_gpus=None, memory=None):
+        assigned_resources = ray.get_runtime_context().get_assigned_resources()
+
+        for resource, ray_resource_key in [(num_cpus, "CPU"), (num_gpus, "GPU"), (memory, "memory")]:
+            if resource is None:
+                assert ray_resource_key not in assigned_resources or assigned_resources[ray_resource_key] is None
+            else:
+                assert ray_resource_key in assigned_resources
+                assert assigned_resources[ray_resource_key] == resource
+
+        return c
+
+
 RAY_VERSION_LT_2 = int(ray.__version__.split(".")[0]) < 2
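The Ray tests below exercise per-expression resource overrides for stateful UDFs: `override_options` reshapes the resource request that the actor pool presents to Ray, and the keyword arguments to the UDF call let the actor assert what it actually received. A hedged sketch of the call pattern, reusing the `AssertResourcesStateful` UDF defined above (it only runs under the RayRunner with the actor pool enabled, as the skip conditions in the tests spell out):

    # Sketch: pooled UDF whose Ray resource request is tuned per expression.
    checked = AssertResourcesStateful.with_concurrency(1)
    tuned = checked.override_options(num_cpus=1, memory_bytes=5_000_000)
    df = df.with_column("ok", tuned(col("id"), num_cpus=1, memory=5_000_000))
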
@@ -187,6 +218,51 @@ def test_with_column_folded_rayrunner():
     df.collect()
 
 
+@pytest.mark.skipif(
+    RAY_VERSION_LT_2, reason="The ray.get_runtime_context().get_assigned_resources() was only added in Ray >= 2.0"
+)
+@pytest.mark.skipif(get_context().runner_config.name not in {"ray"}, reason="requires RayRunner to be in use")
+def test_with_column_rayrunner_class(enable_actor_pool):
+    assert_resources = AssertResourcesStateful.with_concurrency(1)
+
+    df = daft.from_pydict(DATA).repartition(2)
+
+    assert_resources_parametrized = assert_resources.override_options(num_cpus=1, memory_bytes=1_000_000, num_gpus=None)
+    df = df.with_column(
+        "resources_ok",
+        assert_resources_parametrized(col("id"), num_cpus=1, num_gpus=None, memory=1_000_000),
+    )
+
+    df.collect()
+
+
+@pytest.mark.skipif(
+    RAY_VERSION_LT_2, reason="The ray.get_runtime_context().get_assigned_resources() was only added in Ray >= 2.0"
+)
+@pytest.mark.skipif(get_context().runner_config.name not in {"ray"}, reason="requires RayRunner to be in use")
+def test_with_column_folded_rayrunner_class(enable_actor_pool):
+    assert_resources = AssertResourcesStateful.with_concurrency(1)
+
+    df = daft.from_pydict(DATA).repartition(2)
+
+    df = df.with_column(
+        "no_requests",
+        assert_resources(col("id"), num_cpus=1),  # UDFs have 1 CPU by default
+    )
+
+    assert_resources_1 = assert_resources.override_options(num_cpus=1, memory_bytes=5_000_000)
+    df = df.with_column(
+        "more_memory_request",
+        assert_resources_1(col("id"), num_cpus=1, memory=5_000_000),
+    )
+    assert_resources_2 = assert_resources.override_options(num_cpus=1, memory_bytes=None)
+    df = df.with_column(
+        "more_cpu_request",
+        assert_resources_2(col("id"), num_cpus=1),
+    )
+    df.collect()
+
+
 ###
 # GPU tests - can only run if machine has a GPU
 ###
diff --git a/tutorials/delta_lake/2-distributed-batch-inference.ipynb b/tutorials/delta_lake/2-distributed-batch-inference.ipynb
index 8462a74e0f..478980b9d9 100644
--- a/tutorials/delta_lake/2-distributed-batch-inference.ipynb
+++ b/tutorials/delta_lake/2-distributed-batch-inference.ipynb
@@ -138,7 +138,7 @@
     "\n",
     "# Prune data\n",
     "df = df.limit(NUM_ROWS)\n",
-    "df = df.where(df[\"object\"].list.lengths() == 1)"
+    "df = df.where(df[\"object\"].list.length() == 1)"
    ]
   },
  {
"; @@ -714,7 +720,7 @@ impl Table { // Begin row. res.push_str("
...